/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_MPI

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/contrib/mpi_collectives/kernels/ring.h"

namespace tensorflow {
namespace contrib {
namespace mpi_collectives {

using CPUDevice = Eigen::ThreadPoolDevice;

// Map each supported element type to the corresponding MPI datatype and
// TensorFlow DataType enum value; the templated ring primitives use these
// traits to remain type-generic.
template <>
MPI_Datatype MPIType<float>() {
  return MPI_FLOAT;
}
template <>
MPI_Datatype MPIType<int>() {
  return MPI_INT;
}
template <>
MPI_Datatype MPIType<long long>() {
  return MPI_LONG_LONG;
}

template <>
DataType TensorFlowDataType<float>() {
  return DT_FLOAT;
}
template <>
DataType TensorFlowDataType<int>() {
  return DT_INT32;
}
template <>
DataType TensorFlowDataType<long long>() {
  return DT_INT64;
}
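
// For illustration only (not from the original file): any MPI call that needs
// a datatype for an element type T can be written generically, e.g.
//   MPI_Send(buf, count, MPIType<T>(), dest, tag, MPI_COMM_WORLD);
// and TensorFlowDataType<T>() can be passed wherever a tensorflow::DataType
// is expected, for example when allocating an output tensor.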

// Generate all necessary specializations for RingAllreduce.
template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, const Tensor*,
                                              Tensor*, Tensor*);
template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
                                                    const Tensor*, Tensor*,
                                                    Tensor*);
template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, const Tensor*,
                                                Tensor*, Tensor*);

// Generate all necessary specializations for RingAllgather.
template Status RingAllgather<GPUDevice, int>(OpKernelContext*, const Tensor*,
                                              const std::vector<size_t>&,
                                              Tensor*);
template Status RingAllgather<GPUDevice, long long>(OpKernelContext*,
                                                    const Tensor*,
                                                    const std::vector<size_t>&,
                                                    Tensor*);
template Status RingAllgather<GPUDevice, float>(OpKernelContext*, const Tensor*,
                                                const std::vector<size_t>&,
                                                Tensor*);
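
// The explicit instantiations above make nvcc emit the GPU versions of the
// templated ring primitives declared in ring.h in this translation unit, so
// code in other translation units can link against them without having to
// compile the template bodies themselves.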

// Synchronously copy data on the GPU, using a stream separate from both the
// default stream and the streams TensorFlow uses, so the copy does not wait
// on operations unrelated to the allreduce.
template <>
void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
  auto stream = CudaStreamForMPI();
  cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
  cudaStreamSynchronize(stream);
}
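
// Added note (hypothetical sketch, not part of the original code): the CUDA
// calls above discard their error codes. If diagnostics were wanted, they
// could be checked along these lines:
//   cudaError_t err =
//       cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
//   if (err != cudaSuccess) {
//     LOG(ERROR) << "cudaMemcpyAsync failed: " << cudaGetErrorString(err);
//   }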

// Elementwise accumulation kernel for GPU. The grid-stride loop lets the fixed
// <<<32, 256>>> launch configuration used below handle a tensor of any length
// N: each thread processes indices i, i + blockDim.x * gridDim.x, and so on.
template <typename T>
__global__ void elemwise_accum(T* out, const T* in, const size_t N) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    out[i] += in[i];
  }
}

// Synchronously accumulate tensors on the GPU, using a stream separate from
// both the default stream and the streams TensorFlow uses, so the accumulation
// does not wait on operations unrelated to the allreduce.
#define GENERATE_ACCUMULATE(type)                                    \
  template <>                                                        \
  void AccumulateTensorData<GPUDevice, type>(type * dst, type * src, \
                                             size_t size) {          \
    auto stream = CudaStreamForMPI();                                \
    elemwise_accum<type><<<32, 256, 0, stream>>>(dst, src, size);    \
    cudaStreamSynchronize(stream);                                   \
  };
GENERATE_ACCUMULATE(int);
GENERATE_ACCUMULATE(long long);
GENERATE_ACCUMULATE(float);
#undef GENERATE_ACCUMULATE
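
// Hedged illustration (the buffer names below are assumptions, not taken from
// ring.h): the templated ring code is expected to call these helpers roughly as
//   CopyTensorData<GPUDevice>(dst_ptr, src_ptr, num_bytes);
//   AccumulateTensorData<GPUDevice, float>(segment_ptr, recv_ptr, num_elems);
// Both helpers synchronize on their private CUDA stream before returning, so
// the device memory they produce is fully written before any subsequent MPI
// transfer touches it.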

}  // namespace mpi_collectives
}  // namespace contrib
}  // namespace tensorflow
#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_USE_MPI