/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
156e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
166e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define EIGEN_USE_THREADS
176e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
18008910f1122d115a6d7430bfcc63cf4296c7467dJonathan Hseu#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h"
194463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower#include <algorithm>
206e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#include "tensorflow/core/framework/op.h"
216e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#include "tensorflow/core/framework/op_kernel.h"
226e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#include "tensorflow/core/framework/register_types.h"
236e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#include "tensorflow/core/lib/core/threadpool.h"
246e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
256e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpnamespace tensorflow {
266e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
276e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpusing GPUDevice = Eigen::GpuDevice;
286e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpusing CPUDevice = Eigen::ThreadPoolDevice;
296e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpusing thread::ThreadPool;
306e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
316e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpnamespace functor {
326e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
336e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define CPUReduceSliceFunctorReduceop(reduceop, beginning)                    \
346e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  template <typename T, typename Index>                                       \
356e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  struct ReduceSliceFunctor##reduceop<CPUDevice, T, Index> {                  \
366e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp   private:                                                                   \
376e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    struct XYZ {                                                              \
386e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      Index x, y, z;                                                          \
396e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      XYZ() = default;                                                        \
406e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      XYZ(Index x, Index y, Index z) : x(x), y(y), z(z) {}                    \
416e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    };                                                                        \
426e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    inline static XYZ global_index_to_xyz(Index global, XYZ size) {           \
436e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      XYZ ret;                                                                \
446e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      ret.x = global / (size.y * size.z);                                     \
456e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      ret.y = global % (size.y * size.z) / size.z;                            \
466e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      ret.z = global % size.z;                                                \
476e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      return ret;                                                             \
486e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    }                                                                         \
496e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                                                                              \
506e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp   public:                                                                    \
516e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    virtual ~ReduceSliceFunctor##reduceop() {}                                \
526e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    virtual void operator()(OpKernelContext* ctx, const CPUDevice& d,         \
536e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                            Index indices_width,                              \
546e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                            typename TTypes<Index, 1>::ConstTensor indices,   \
556e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                            typename TTypes<T, 3>::ConstTensor data,          \
566e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                            typename TTypes<T, 3>::Tensor output) {           \
576e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      Index bound = data.dimension(1);                                        \
586e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      Index dim1 = output.dimension(0);                                       \
596e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      Index dim2 = output.dimension(1);                                       \
606e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      Index dim3 = output.dimension(2);                                       \
616e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      Index size = dim1 * dim2 * dim3;                                        \
626e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      if (size == 0) {                                                        \
636e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp        return;                                                               \
646e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      }                                                                       \
656e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      T zero = beginning<T>();                                                \
666e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      ThreadPool* thread_pool =                                               \
676e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          ctx->device()->tensorflow_cpu_worker_threads()->workers;            \
686e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      /* shard the work */                                                    \
696e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      auto work = [&](Index start, Index end) {                               \
706e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp        for (Index global = start; global < end; ++global) {                  \
716e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          XYZ xyz = global_index_to_xyz(global, XYZ(dim1, dim2, dim3));       \
726e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          Index x = xyz.x;                                                    \
736e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          Index y = xyz.y;                                                    \
746e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          Index z = xyz.z;                                                    \
756e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          output(x, y, z) = zero;                                             \
766e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          Index slice_head = indices(y * indices_width);                      \
776e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          Index slice_end = std::min(indices(y * indices_width + 1), bound);  \
786e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          for (Index i = slice_head; i < slice_end; ++i) {                    \
796e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp            output(x, y, z) = reduceop(output(x, y, z), data(x, i, z));       \
806e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          }                                                                   \
816e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp        }                                                                     \
826e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      };                                                                      \
836e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      /* Here assumes the number of average CPU cycles for each slice equals  \
846e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp       * the average length of each slice */                                  \
856e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      thread_pool->ParallelFor(size, std::max(bound / dim2, (Index)1), work); \
866e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    }                                                                         \
876e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  };
886e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
896e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpCALL_ALL_REDUCEOPS(CPUReduceSliceFunctorReduceop)
906e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef CPUReduceSliceFunctorReduceop
916e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
926e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define DEFINE_CPU_SUMPROD_SPECS_INDEX(T, Index)              \
936e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  template struct ReduceSliceFunctorSum<CPUDevice, T, Index>; \
946e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  template struct ReduceSliceFunctorProd<CPUDevice, T, Index>;
956e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
966e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define DEFINE_CPU_MINMAX_SPECS_INDEX(T, Index)               \
976e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  template struct ReduceSliceFunctorMax<CPUDevice, T, Index>; \
986e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  template struct ReduceSliceFunctorMin<CPUDevice, T, Index>;
996e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1006e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define DEFINE_CPU_SUMPROD_SPECS(T)         \
1016e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  DEFINE_CPU_SUMPROD_SPECS_INDEX(T, int32); \
1026e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  DEFINE_CPU_SUMPROD_SPECS_INDEX(T, int64);
1036e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1046e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define DEFINE_CPU_MINMAX_SPECS(T)         \
1056e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  DEFINE_CPU_MINMAX_SPECS_INDEX(T, int32); \
1066e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  DEFINE_CPU_MINMAX_SPECS_INDEX(T, int64);
1076e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1086e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpTF_CALL_NUMBER_TYPES(DEFINE_CPU_SUMPROD_SPECS)
1096e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpTF_CALL_REAL_NUMBER_TYPES(DEFINE_CPU_MINMAX_SPECS)
1106e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1116e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef DEFINE_CPU_SUMPROD_SPECS_INDEX
1126e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef DEFINE_CPU_MINMAX_SPECS_INDEX
1136e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef DEFINE_CPU_SUMPROD_SPECS
1146e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef DEFINE_CPU_MINMAX_SPECS
1156e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1166e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp}  // namespace functor
1176e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1186e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harptemplate <typename Device, typename T, typename Index,
1196e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          template <typename Device2, typename T2, typename Index2>
1206e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp          class Functor>
1216e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpclass ReduceSliceKernel : public OpKernel {
1226e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp public:
1236e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  explicit ReduceSliceKernel(OpKernelConstruction* context)
1246e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      : OpKernel(context) {}
1256e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1266e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  void Compute(OpKernelContext* context) override {
1276e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    const Tensor& data = context->input(0);
1286e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    const Tensor& indices = context->input(1);
1296e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    const Tensor& _axis = context->input(2);
1306e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    int64 axis = _axis.scalar<int64>()();
1316e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1326e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    int indices_width = 2;
1336e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    int out_axis_dim_size = indices.shape().dim_size(0);
1346e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    if (indices.dims() == 1 || indices.shape().dim_size(1) == 1) {
1356e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      indices_width = 1;
1366e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      if (out_axis_dim_size > 0) {
1376e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp        out_axis_dim_size--;
1386e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp      }
1396e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    }
1406e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1416e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    TensorShape output_shape = data.shape();
1426e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    output_shape.set_dim(axis, out_axis_dim_size);
1436e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    Tensor* output = nullptr;
1446e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
1456e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    auto functor = Functor<Device, T, Index>();
1466e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp    functor(context, context->eigen_device<Device>(), indices_width,
1476e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp            indices.flat<Index>(), data.flat_inner_outer_dims<T, 3>(axis - 1),
1486e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp            output->flat_inner_outer_dims<T, 3>(axis - 1));
1496e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  }
1506e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp};
1516e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1526e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, index_type)           \
1536e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  REGISTER_KERNEL_BUILDER(Name("ReduceSliceSum")                              \
1546e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .Device(DEVICE_CPU)                             \
1556e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .TypeConstraint<type>("T")                      \
1566e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .TypeConstraint<index_type>("Tindices"),        \
1576e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                          ReduceSliceKernel<CPUDevice, type, index_type,      \
1586e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                                            functor::ReduceSliceFunctorSum>); \
1596e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  REGISTER_KERNEL_BUILDER(Name("ReduceSliceProd")                             \
1606e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .Device(DEVICE_CPU)                             \
1616e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .TypeConstraint<type>("T")                      \
1626e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .TypeConstraint<index_type>("Tindices"),        \
1636e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                          ReduceSliceKernel<CPUDevice, type, index_type,      \
1646e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                                            functor::ReduceSliceFunctorProd>);
1656e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1666e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, index_type)            \
1676e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  REGISTER_KERNEL_BUILDER(Name("ReduceSliceMax")                              \
1686e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .Device(DEVICE_CPU)                             \
1696e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .TypeConstraint<type>("T")                      \
1706e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .TypeConstraint<index_type>("Tindices"),        \
1716e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                          ReduceSliceKernel<CPUDevice, type, index_type,      \
1726e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                                            functor::ReduceSliceFunctorMax>); \
1736e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  REGISTER_KERNEL_BUILDER(Name("ReduceSliceMin")                              \
1746e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .Device(DEVICE_CPU)                             \
1756e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .TypeConstraint<type>("T")                      \
1766e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                              .TypeConstraint<index_type>("Tindices"),        \
1776e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                          ReduceSliceKernel<CPUDevice, type, index_type,      \
1786e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp                                            functor::ReduceSliceFunctorMin>);
1796e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1806e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL(type) \
1816e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, int32);   \
1826e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, int64);
1836e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1846e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL(type) \
1856e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, int32);   \
1866e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp  REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, int64);
1876e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1886e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpTF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL)
1896e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpTF_CALL_NUMBER_TYPES(REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL)
1906e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
1916e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS
1926e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS
1936e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL
1946e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL
1956e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
#if GOOGLE_CUDA

// GPU kernel registration (compiled only when GOOGLE_CUDA is set).
//
// REGISTER_GPU_REDUCE_SLICE_KERNEL registers a single kernel: the op named
// `op_name` on DEVICE_GPU with the "axis" input kept in host memory,
// constrained to element type `type` ("T") and index type `index_type`
// ("Tindices"), implemented by ReduceSliceKernel with functor::`functor_name`.
#define REGISTER_GPU_REDUCE_SLICE_KERNEL(op_name, functor_name, type,         \
                                         index_type)                          \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name(op_name)                                                           \
          .Device(DEVICE_GPU)                                                 \
          .HostMemory("axis")                                                 \
          .TypeConstraint<type>("T")                                          \
          .TypeConstraint<index_type>("Tindices"),                            \
      ReduceSliceKernel<GPUDevice, type, index_type, functor::functor_name>)

// All four reductions for one (element type, index type) pair.
#define REGISTER_GPU_KERNELS_FOR_INDEX(type, index_type)                      \
  REGISTER_GPU_REDUCE_SLICE_KERNEL("ReduceSliceSum", ReduceSliceFunctorSum,   \
                                   type, index_type);                         \
  REGISTER_GPU_REDUCE_SLICE_KERNEL("ReduceSliceProd", ReduceSliceFunctorProd, \
                                   type, index_type);                         \
  REGISTER_GPU_REDUCE_SLICE_KERNEL("ReduceSliceMax", ReduceSliceFunctorMax,   \
                                   type, index_type);                         \
  REGISTER_GPU_REDUCE_SLICE_KERNEL("ReduceSliceMin", ReduceSliceFunctorMin,   \
                                   type, index_type);

// Per-element-type fan-out over the int32 and int64 index types.
#define REGISTER_GPU_KERNELS_FOR_TYPE(type)   \
  REGISTER_GPU_KERNELS_FOR_INDEX(type, int32) \
  REGISTER_GPU_KERNELS_FOR_INDEX(type, int64)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_KERNELS_FOR_TYPE)

#undef REGISTER_GPU_REDUCE_SLICE_KERNEL
#undef REGISTER_GPU_KERNELS_FOR_INDEX
#undef REGISTER_GPU_KERNELS_FOR_TYPE

#endif  // GOOGLE_CUDA
2386e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp
2396e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp}  // namespace tensorflow
240