/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
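//
// SplitV splits 'value' along 'split_dim' into tensors whose sizes along that
// dimension are given by 'size_splits'. At most one entry of 'size_splits'
// may be -1, in which case it absorbs whatever remains. For example, a
// [4, 6] tensor split along dimension 1 with size_splits = {2, 3, 1} yields
// outputs of shape [4, 2], [4, 3], and [4, 1].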

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include <algorithm>
#include <limits>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Tlen>
class SplitVOpBase : public OpKernel {
 public:
  explicit SplitVOpBase(OpKernelConstruction* c) : OpKernel(c) {}

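  // Validates the inputs, resolves an optional -1 entry in the split sizes,
  // and handles the cases that need no data movement: a single output, and a
  // suitably aligned split along dimension 0 whose outputs can alias the
  // input buffer. Sets *done to true when the outputs have already been
  // produced.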
  void ComputeEasyCases(OpKernelContext* context, bool* done,
                        std::vector<Tlen>* split_sizes_vec) {
    const int32 num_split = context->num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const Tensor& split_tensor = context->input(1);

    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    OP_REQUIRES(
        context,
        split_tensor.dims() == 1 && split_tensor.NumElements() == num_split,
        errors::InvalidArgument("split_tensor must be 1-D and have as many "
                                "elements as there are outputs, but got a ",
                                split_tensor.dims(), "-D tensor with ",
                                split_tensor.NumElements(), " elements"));

    auto split_sizes_d = split_tensor.vec<Tlen>();

    split_sizes_vec->resize(split_sizes_d.size());

    std::copy(split_sizes_d.data(), split_sizes_d.data() + split_sizes_d.size(),
              split_sizes_vec->begin());

    OP_REQUIRES(
        context, num_split > 0,
        errors::InvalidArgument(
            "Number of ways to split should be > 0, but got ", num_split));

    OP_REQUIRES(
        context, 0 <= split_dim && split_dim < input.dims(),
        errors::InvalidArgument("-input rank(-", input.dims(),
                                ") <= split_dim < input rank (", input.dims(),
                                "), but got ", split_dim_orig));

    Tlen input_size_split_dim = input_shape.dim_size(split_dim);

    // Special case 1: num_split == 1. Nothing to do.
    if (num_split == 1) {
      context->set_output(0, context->input(0));
      OP_REQUIRES(
          context, (*split_sizes_vec)[0] == input_size_split_dim,
          errors::InvalidArgument("If there is only one output, it must have "
                                  "the same size as the input. Input size: ",
                                  input_size_split_dim,
                                  " output size: ", (*split_sizes_vec)[0]));
      *done = true;
      return;
    }

    // Determine the output sizes, resolving at most one -1 (inferred) entry.
    int neg_one_dim = -1;
    Tlen determined_size = 0;
    for (int d = 0; d < split_sizes_vec->size(); ++d) {
      Tlen size = (*split_sizes_vec)[d];

      if (size == -1) {
        OP_REQUIRES(context, neg_one_dim == -1,
                    errors::InvalidArgument("There can only be one -1 in the "
                                            "input."));
        neg_one_dim = d;
      } else {
        determined_size += size;
      }
    }

    OP_REQUIRES(
        context,
        (neg_one_dim == -1 && determined_size == input_size_split_dim) ||
            (neg_one_dim >= 0 && determined_size <= input_size_split_dim),
        errors::InvalidArgument("Determined shape must either match "
                                "input shape along split_dim exactly if "
                                "fully specified, or be less than the size of "
                                "the input along split_dim if not fully "
                                "specified.  Got: ",
                                determined_size));

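    // For example, with an input of size 10 along split_dim and split sizes
    // {2, -1, 3}, the -1 entry resolves to 10 - (2 + 3) = 5.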
    if (neg_one_dim >= 0) {
      (*split_sizes_vec)[neg_one_dim] = input_size_split_dim - determined_size;
    }

    // Special case 2: split along the 1st dimension. We can share the
    // underlying buffer.
    //
    // Apply this optimization conservatively: if the input is aligned,
    // the resulting tensors must be aligned. It's conservative because
    // if the immediate consumers of the resulting tensors do not use
    // Eigen for computation, it's perfectly fine to avoid the copying.
    if ((split_dim == 0) && IsInnerDimsSizeAligned<T>(input_shape)) {
      Tlen start = 0;
      for (int i = 0; i < num_split; ++i) {
        context->set_output(i,
                            input.Slice(start, start + (*split_sizes_vec)[i]));
        start += (*split_sizes_vec)[i];
      }
      *done = true;
      return;
    }
  }

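  // Collapses input_shape into (prefix, split_dim, suffix) sizes around
  // split_dim. For example, shape {2, 3, 4, 5} with split_dim == 2 yields
  // (6, 4, 5).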
  template <typename IndexType>
  std::tuple<IndexType, IndexType, IndexType> SetDims(
      const TensorShape& input_shape, const int32 split_dim) const {
    static_assert(std::is_integral<IndexType>::value,
                  "IndexType must be an integer type");
    int32 prefix_dim_size = 1;
    for (int i = 0; i < split_dim; ++i) {
      prefix_dim_size *= input_shape.dim_size(i);
    }

    // Caller must ensure that split_dim_size and suffix_dim_size are <
    // std::numeric_limits<IndexType>::max()
    IndexType split_dim_size =
        static_cast<IndexType>(input_shape.dim_size(split_dim));

    IndexType suffix_dim_size = 1;
    for (int i = split_dim + 1; i < input_shape.dims(); ++i) {
      suffix_dim_size *= static_cast<IndexType>(input_shape.dim_size(i));
    }
    return std::make_tuple(prefix_dim_size, split_dim_size, suffix_dim_size);
  }
};

template <typename T, typename Tlen>
class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
 public:
  typedef SplitVOpBase<CPUDevice, T, Tlen> Base;
  explicit SplitVOpCPU(OpKernelConstruction* c) : Base(c) {}

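  // After the easy cases are handled, either shards the copying across
  // outputs (one sequential slice per output) or emits the outputs one at a
  // time and lets functor::Split parallelize each copy internally, depending
  // on the heuristic below.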
  void Compute(OpKernelContext* context) override {
    bool done = false;
    std::vector<Tlen> split_sizes_vec;
    Base::ComputeEasyCases(context, &done, &split_sizes_vec);
    if (!context->status().ok() || done) {
      return;
    }
    const int32 num_split = Base::num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    // Android also uses int32 indexing, so check here also.
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("Split requires input size < ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    Eigen::DenseIndex prefix_dim_size;
    Eigen::DenseIndex split_dim_size;
    Eigen::DenseIndex suffix_dim_size;

    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
    auto input_reshaped =
        input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});

    Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
    std::vector<int64> split_start_points(num_split);
    for (int i = 0; i < num_split; ++i) {
      if (i == 0) {
        split_start_points[i] = 0;
      } else {
        split_start_points[i] =
            split_start_points[i - 1] + split_sizes_vec[i - 1];
      }
    }

    const auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // TODO(jewillco): Tune heuristic further.
    const auto input_element_count = input_shape.num_elements();
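    // Heuristic sketch: shard across outputs only when there are at least
    // four outputs, the input is large enough to amortize the sharding
    // overhead, and the average output is small enough (under ~180K elements)
    // that a single thread can copy it quickly.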
    const bool use_parallelism_between_outputs =
        (num_split >= 4 &&
         input_element_count >= std::max(num_threads, num_split) * 4096 &&
         input_element_count < num_split * 180 * 1024);

    auto range_output_func = [&indices, context, &input_shape, prefix_dim_size,
                              split_dim, &split_sizes_vec, &split_start_points,
                              suffix_dim_size, use_parallelism_between_outputs,
                              &input_reshaped](int64 start, int64 limit) {
      for (int64 i = start; i < limit; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));

        Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
            prefix_dim_size, split_sizes_vec[i], suffix_dim_size};

        if (sizes.TotalSize() > 0) {
          auto result_shaped = result->shaped<T, 3>(
              {prefix_dim_size, split_sizes_vec[i], suffix_dim_size});

          auto current_indices = indices;
          current_indices[1] = split_start_points[i];
          if (use_parallelism_between_outputs) {
            // Copy this output sequentially; parallelism comes from sharding
            // across outputs.
            result_shaped = input_reshaped.slice(current_indices, sizes);
          } else {
            // This implementation may be parallel internally.
            functor::Split<CPUDevice, T>()(context->eigen_device<CPUDevice>(),
                                           result_shaped, input_reshaped,
                                           current_indices, sizes);
          }
        }
      }
    };
    if (use_parallelism_between_outputs) {
      // Run in parallel, disabling parallelism in functor.
      Shard(num_split,
            context->device()->tensorflow_cpu_worker_threads()->workers,
            num_split, input_element_count / num_split, range_output_func);
    } else {
      // Run sequentially, but allow internal parallelism in functor.
      range_output_func(0, num_split);
    }
  }
};

#if GOOGLE_CUDA

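// Launches a CUDA kernel that copies the 2-D-viewed input into the per-output
// buffers. 'output_scan' holds the cumulative column offsets of the outputs,
// 'output_ptr_data' holds device pointers to their buffers, and 'fixed'
// indicates that all outputs have the same size.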
template <typename T, typename IntType>
struct SplitVOpGPULaunch {
  void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
           int total_cols, int total_rows,
           const CudaDeviceArrayStruct<IntType>& output_scan,
           const CudaDeviceArrayStruct<T*>& output_ptr_data);
};

// Partial specialization for GPU
template <typename T, typename Tlen>
class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
 public:
  typedef SplitVOpBase<GPUDevice, T, Tlen> Base;
  explicit SplitVOpGPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    std::vector<Tlen> split_sizes_vec;
    Base::ComputeEasyCases(context, &done, &split_sizes_vec);
    if (!context->status().ok() || done) {
      return;
    }
    const int32 num_split = Base::num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(), std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Split on GPU requires input size "
                                "< max int32"));

    int32 prefix_dim_size;
    int32 split_dim_size;
    int32 suffix_dim_size;
    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<int32>(input_shape, split_dim);

    // use the same approach as concat (see documentation there)
    // reshape to 2D

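    // The input is viewed as a 2-D tensor of shape
    // {prefix_dim_size, split_dim_size * suffix_dim_size}; each output owns a
    // contiguous range of columns. With many outputs, a single fused CUDA
    // kernel is driven by device-side arrays of output pointers and column
    // offsets; with few outputs, one Eigen slice is issued per output.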
    if (num_split > 16) {
      CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
      OP_REQUIRES_OK(context, ptrs.Init());

      CudaDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
      OP_REQUIRES_OK(context, offsets.Init());

      Tlen offset = 0;
      Tlen entry = split_sizes_vec[0];
      bool fixed_size =
          std::all_of(split_sizes_vec.begin(), split_sizes_vec.end(),
                      [&entry](Tlen n) { return n == entry; });

      for (int i = 0; i < num_split; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));
        ptrs.Set(i, result->flat<T>().data());
        offsets.Set(i, offset);
        offset += split_sizes_vec[i] * suffix_dim_size;
      }
      offsets.Set(num_split, offset);
      OP_REQUIRES_OK(context, ptrs.Finalize());
      OP_REQUIRES_OK(context, offsets.Finalize());

      if (input.NumElements() > 0) {
        SplitVOpGPULaunch<T, Tlen>().Run(
            context->eigen_device<GPUDevice>(), fixed_size,
            input.flat<T>().data(), prefix_dim_size,
            input.NumElements() / prefix_dim_size, offsets.data(), ptrs.data());
        OP_REQUIRES(
            context, context->op_device_context()->stream()->ok(),
            errors::Internal("Launch of gpu kernel for SplitVOp failed"));
      }
    } else {
      Eigen::DenseIndex prefix_dim_size;
      Eigen::DenseIndex split_dim_size;
      Eigen::DenseIndex suffix_dim_size;

      std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
          Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
      auto input_reshaped = input.shaped<T, 2>(
          {prefix_dim_size, split_dim_size * suffix_dim_size});

      Eigen::DSizes<Eigen::DenseIndex, 2> indices{0, 0};

      for (int i = 0; i < num_split; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));

        Eigen::DSizes<Eigen::DenseIndex, 2> sizes{
            prefix_dim_size, split_sizes_vec[i] * suffix_dim_size};

        if (sizes.TotalSize() > 0) {
          auto result_shaped = result->shaped<T, 2>(
              {prefix_dim_size, split_sizes_vec[i] * suffix_dim_size});

          functor::SplitCustom<GPUDevice, T>()(
              context->eigen_device<GPUDevice>(), result_shaped, input_reshaped,
              indices, sizes);
        }
        indices[1] += split_sizes_vec[i] * suffix_dim_size;
      }
    }
  }
};
#endif  // GOOGLE_CUDA

#define REGISTER_SPLIT(type, len_type)                          \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_CPU)               \
                              .TypeConstraint<len_type>("Tlen") \
                              .TypeConstraint<type>("T")        \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim"),         \
                          SplitVOpCPU<type, len_type>);

#define REGISTER_SPLIT_LEN(type) \
  REGISTER_SPLIT(type, int32);   \
  REGISTER_SPLIT(type, int64);

TF_CALL_ALL_TYPES(REGISTER_SPLIT_LEN);

#undef REGISTER_SPLIT_LEN
#undef REGISTER_SPLIT

#if GOOGLE_CUDA

#define REGISTER_GPU(type, len_type)                            \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<len_type>("Tlen") \
                              .TypeConstraint<type>("T")        \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim"),         \
                          SplitVOpGPU<type, len_type>);

#define REGISTER_GPU_LEN(type) \
  REGISTER_GPU(type, int32);   \
  REGISTER_GPU(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_LEN);
TF_CALL_complex64(REGISTER_GPU_LEN);
TF_CALL_complex128(REGISTER_GPU_LEN);
REGISTER_GPU_LEN(bfloat16);
#undef REGISTER_GPU_LEN
#undef REGISTER_GPU

// Special registration for int32 on the GPU device: it uses the CPU kernel
// and keeps all inputs and outputs in host memory, since TensorFlow places
// int32 tensors on the host.

#define REGISTER_GPU_int32(len_type)                            \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<int32>("T")       \
                              .TypeConstraint<len_type>("Tlen") \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim")          \
                              .HostMemory("value")              \
                              .HostMemory("output"),            \
                          SplitVOpCPU<int32, len_type>);

REGISTER_GPU_int32(int32);
REGISTER_GPU_int32(int64);

#undef REGISTER_GPU_int32

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow