/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
class SplitOpBase : public OpKernel {
 public:
  explicit SplitOpBase(OpKernelConstruction* c) : OpKernel(c) {}

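  // Handles the cases that can be resolved without copying any data:
  //  1. num_split == 1: forward the input tensor as the single output.
  //  2. split_dim == 0 with aligned inner dimensions: each output aliases a
  //     contiguous slice of the input buffer.
  // Sets *done to true when one of these fast paths produced the outputs.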
  void ComputeEasyCases(OpKernelContext* context, bool* done) {
    const Tensor& input = context->input(1);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    const int32 num_split = num_outputs();

    OP_REQUIRES(
        context, 0 <= split_dim && split_dim < input_shape.dims(),
        errors::InvalidArgument("-input rank(-", input.dims(),
                                ") <= split_dim < input rank (", input.dims(),
                                "), but got ", split_dim_orig));

    OP_REQUIRES(
        context, num_split > 0,
        errors::InvalidArgument(
            "Number of ways to split should be > 0, but got ", num_split));

    OP_REQUIRES(context, input_shape.dim_size(split_dim) % num_split == 0,
                errors::InvalidArgument(
                    "Number of ways to split should evenly divide the split "
                    "dimension, but got split_dim ",
                    split_dim, " (size = ", input_shape.dim_size(split_dim),
                    ") ", "and num_split ", num_split));
    // Special case 1: num_split == 1. Nothing to do.
    if (num_split == 1) {
      VLOG(1) << "Split identity";
      context->set_output(0, context->input(1));
      *done = true;
      return;
    }

    // Special case 2: split along the 1st dimension. We can share the
    // underlying buffer.
    //
    // Apply this optimization conservatively: if the input is aligned,
    // the resulting tensors must be aligned. It's conservative because
    // if the immediate consumers of the resulting tensors do not use
    // Eigen for computation, it's perfectly fine to avoid the copying
    // even when the slices are unaligned.
    if ((split_dim == 0) && IsInnerDimsSizeAligned<T>(input_shape)) {
      VLOG(1) << "Slice dim 0: " << input_shape.DebugString();
      const int64 delta = input_shape.dim_size(0) / num_split;
      for (int i = 0; i < num_split; ++i) {
        context->set_output(i, input.Slice(i * delta, (i + 1) * delta));
      }
      *done = true;
      return;
    }
  }

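  // Collapses `input_shape` into (prefix, split_dim_size, suffix), where
  // prefix is the product of the dimensions before split_dim and suffix is
  // the product of the dimensions after it. For example, an input of shape
  // [2, 6, 3, 5] with split_dim == 1 yields (2, 6, 15).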
  template <typename IndexType>
  std::tuple<IndexType, IndexType, IndexType> SetDims(
      const TensorShape& input_shape, int32 split_dim) const {
    static_assert(std::is_integral<IndexType>::value,
                  "IndexType must be an integer type");
    int32 prefix_dim_size = 1;
    for (int i = 0; i < split_dim; ++i) {
      prefix_dim_size *= input_shape.dim_size(i);
    }

    // Caller must ensure that dim_size and suffix_dim_size are <
    // std::numeric_limits<IndexType>::max()
    IndexType split_dim_size =
        static_cast<IndexType>(input_shape.dim_size(split_dim));

    IndexType suffix_dim_size = 1;
    for (int i = split_dim + 1; i < input_shape.dims(); ++i) {
      suffix_dim_size *= static_cast<IndexType>(input_shape.dim_size(i));
    }
    return std::make_tuple(prefix_dim_size, split_dim_size, suffix_dim_size);
  }
};

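// CPU implementation: reshapes the input to the 3-D view produced by
// SetDims() and copies one slice along the middle dimension into each
// output. Depending on the problem size, it either shards the outputs
// across threads or lets the Eigen device parallelize each copy.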
template <typename T>
class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
 public:
  typedef SplitOpBase<CPUDevice, T> Base;
  explicit SplitOpCPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    Base::ComputeEasyCases(context, &done);
    if (!context->status().ok() || done) {
      return;
    }
    const int32 num_split = Base::num_outputs();
    const Tensor& input = context->input(1);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    // Android also uses int32 indexing, so check here also.
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("Split requires input size < ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    Eigen::DenseIndex prefix_dim_size;
    Eigen::DenseIndex split_dim_size;
    Eigen::DenseIndex suffix_dim_size;

    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
    auto input_reshaped =
        input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});

    const int64 split_dim_output_size = split_dim_size / num_split;
    TensorShape output_shape(input_shape);
    output_shape.set_dim(split_dim, split_dim_output_size);

    Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
        prefix_dim_size, split_dim_output_size, suffix_dim_size};
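    // Output i is the slice of the reshaped input starting at
    // [0, i * split_dim_output_size, 0] with extents `sizes`.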

    const auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // TODO(jewillco): Tune heuristic further.
    const auto input_element_count = input_shape.num_elements();
    const bool use_parallelism_between_outputs =
        (num_split >= 4 &&
         input_element_count >= std::max(num_threads, num_split) * 4096 &&
         input_element_count < num_split * 180 * 1024);

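    // Allocates and fills the outputs in the index range [start, limit).
    // Used either as the shard body (parallel across outputs) or invoked
    // once over all outputs (sequential, with parallelism inside the
    // functor).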
    auto range_output_func = [&indices, context, &output_shape, prefix_dim_size,
                              split_dim_output_size, suffix_dim_size, &sizes,
                              use_parallelism_between_outputs,
                              &input_reshaped](int64 start, int64 limit) {
      for (int64 i = start; i < limit; ++i) {
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));
        if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
          Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices;
          Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes;
          for (int j = 0; j < 3; ++j) {
            slice_indices[j] =
                (j == 1 ? i * split_dim_output_size : indices[j]);
            slice_sizes[j] = sizes[j];
          }

          auto result_shaped = result->shaped<T, 3>(
              {prefix_dim_size, split_dim_output_size, suffix_dim_size});

          if (use_parallelism_between_outputs) {
            // Copy this output sequentially; parallelism comes from sharding
            // across outputs.
            result_shaped = input_reshaped.slice(slice_indices, slice_sizes);
          } else {
            // This implementation may be parallel internally.
            functor::Split<CPUDevice, T>()(context->eigen_device<CPUDevice>(),
                                           result_shaped, input_reshaped,
                                           slice_indices, slice_sizes);
          }
        }
      }
    };
    if (use_parallelism_between_outputs) {
      // Run in parallel, disabling parallelism in functor.
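      // Shard(max_parallelism, workers, total, cost_per_unit, work); the
      // cost estimate is the number of input elements copied per output.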
      Shard(num_split,
            context->device()->tensorflow_cpu_worker_threads()->workers,
            num_split, input_element_count / num_split, range_output_func);
    } else {
      // Run sequentially, but allow internal parallelism in functor.
      range_output_func(0, num_split);
    }
  }
};

#if GOOGLE_CUDA

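// Declaration of the GPU launcher; the kernel itself lives in the
// corresponding .cu.cc file and is instantiated for the registered types.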
template <typename T>
struct SplitOpGPULaunch {
  void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
           int32 split_dim_size, int32 suffix_dim_size,
           const CudaDeviceArrayStruct<T*>& output_ptr_data);
};

// GPU-specific implementation of the Split op.
template <typename T>
class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
 public:
  typedef SplitOpBase<GPUDevice, T> Base;
  explicit SplitOpGPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    Base::ComputeEasyCases(context, &done);
    if (!context->status().ok() || done) {
      return;
    }
    const Tensor& input = context->input(1);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    const int32 num_split = Base::num_outputs();
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(), std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Split on GPU requires input size "
                                "< max int32"));
    int32 prefix_dim_size;
    int32 split_dim_size;
    int32 suffix_dim_size;
    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<int32>(input_shape, split_dim);

    const int32 split_dim_output_size = split_dim_size / num_split;
    TensorShape output_shape(input_shape);
    output_shape.set_dim(split_dim, split_dim_output_size);

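    // Collect the per-output device pointers on the host; Finalize() below
    // makes the array available to the GPU kernel launched via
    // SplitOpGPULaunch.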
    CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
    OP_REQUIRES_OK(context, ptrs.Init());

    for (int i = 0; i < num_split; ++i) {
      Tensor* result = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(i, output_shape, &result));
      ptrs.Set(i, result->flat<T>().data());
    }
    if (prefix_dim_size * split_dim_output_size * suffix_dim_size == 0) {
      return;
    }
    OP_REQUIRES_OK(context, ptrs.Finalize());

    SplitOpGPULaunch<T>().Run(context->eigen_device<GPUDevice>(),
                              input.flat<T>().data(), prefix_dim_size,
                              split_dim_size, suffix_dim_size, ptrs.data());
    OP_REQUIRES(context, context->op_device_context()->stream()->ok(),
                errors::Internal("Launch of gpu kernel for SplitOp failed"));
  }
};
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
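// SYCL implementation: the same 3-D reshape-and-slice strategy as the CPU
// kernel, but each output is copied sequentially via the Split functor.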
template <typename T>
class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
 public:
  typedef SplitOpBase<SYCLDevice, T> Base;
  explicit SplitOpSYCL(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    Base::ComputeEasyCases(context, &done);
    if (!context->status().ok() || done) {
      return;
    }
    const Tensor& input = context->input(1);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    const int32 num_split = Base::num_outputs();

    // Android also uses int32 indexing, so check here also.
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("Split requires input size < ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    Eigen::DenseIndex prefix_dim_size;
    Eigen::DenseIndex split_dim_size;
    Eigen::DenseIndex suffix_dim_size;

    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
    auto input_reshaped =
        input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});

    const int64 split_dim_output_size = split_dim_size / num_split;
    TensorShape output_shape(input_shape);
    output_shape.set_dim(split_dim, split_dim_output_size);

    Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
    Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
        prefix_dim_size, split_dim_output_size, suffix_dim_size};

    for (int i = 0; i < num_split; ++i) {
      Tensor* result = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(i, output_shape, &result));
      if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
        Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices;
        Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes;
        for (int j = 0; j < 3; ++j) {
          slice_indices[j] = indices[j];
          slice_sizes[j] = sizes[j];
        }

        auto result_shaped = result->shaped<T, 3>(
            {prefix_dim_size, split_dim_output_size, suffix_dim_size});

        functor::Split<SYCLDevice, T>()(context->eigen_device<SYCLDevice>(),
                                        result_shaped, input_reshaped,
                                        slice_indices, slice_sizes);
      }
      indices[1] += split_dim_output_size;
    }
  }
};
#endif  // TENSORFLOW_USE_SYCL

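// Kernel registrations. `split_dim` stays in host memory because the kernels
// read its value directly (context->input(0)) to compute output shapes.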
#define REGISTER_SPLIT(type)                             \
  REGISTER_KERNEL_BUILDER(Name("Split")                  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("split_dim"),  \
                          SplitOpCPU<type>)

TF_CALL_ALL_TYPES(REGISTER_SPLIT);
REGISTER_SPLIT(quint8);

#undef REGISTER_SPLIT

#if GOOGLE_CUDA

#define REGISTER_GPU(type)                               \
  REGISTER_KERNEL_BUILDER(Name("Split")                  \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("split_dim"),  \
                          SplitOpGPU<type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type)                              \
  REGISTER_KERNEL_BUILDER(Name("Split")                  \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("split_dim"),  \
                          SplitOpSYCL<type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
#undef REGISTER_SYCL

#endif  // TENSORFLOW_USE_SYCL

}  // end namespace tensorflow