16e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 26e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 36e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpLicensed under the Apache License, Version 2.0 (the "License"); 46e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpyou may not use this file except in compliance with the License. 56e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpYou may obtain a copy of the License at 66e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 76e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp http://www.apache.org/licenses/LICENSE-2.0 86e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 96e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpUnless required by applicable law or agreed to in writing, software 106e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpdistributed under the License is distributed on an "AS IS" BASIS, 116e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 126e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpSee the License for the specific language governing permissions and 136e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harplimitations under the License. 146e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp==============================================================================*/ 156e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 166e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define EIGEN_USE_THREADS 176e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 18008910f1122d115a6d7430bfcc63cf4296c7467dJonathan Hseu#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h" 194463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower#include <algorithm> 206e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#include "tensorflow/core/framework/op.h" 216e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#include "tensorflow/core/framework/op_kernel.h" 226e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#include "tensorflow/core/framework/register_types.h" 236e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#include "tensorflow/core/lib/core/threadpool.h" 246e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 256e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpnamespace tensorflow { 266e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 276e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpusing GPUDevice = Eigen::GpuDevice; 286e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpusing CPUDevice = Eigen::ThreadPoolDevice; 296e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpusing thread::ThreadPool; 306e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 316e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpnamespace functor { 326e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 336e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define CPUReduceSliceFunctorReduceop(reduceop, beginning) \ 346e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp template <typename T, typename Index> \ 356e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp struct ReduceSliceFunctor##reduceop<CPUDevice, T, Index> { \ 366e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp private: \ 376e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp struct XYZ { \ 386e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index x, y, z; \ 396e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp XYZ() = default; \ 406e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp XYZ(Index x, Index y, Index z) : x(x), y(y), z(z) {} \ 416e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp }; \ 426e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp inline static XYZ global_index_to_xyz(Index global, XYZ size) { \ 436e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp XYZ ret; \ 446e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ret.x = global / (size.y * size.z); \ 456e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ret.y = global % (size.y * size.z) / size.z; \ 466e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ret.z = global % size.z; \ 476e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp return ret; \ 486e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp } \ 496e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp \ 506e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp public: \ 516e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp virtual ~ReduceSliceFunctor##reduceop() {} \ 526e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp virtual void operator()(OpKernelContext* ctx, const CPUDevice& d, \ 536e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index indices_width, \ 546e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp typename TTypes<Index, 1>::ConstTensor indices, \ 556e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp typename TTypes<T, 3>::ConstTensor data, \ 566e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp typename TTypes<T, 3>::Tensor output) { \ 576e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index bound = data.dimension(1); \ 586e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index dim1 = output.dimension(0); \ 596e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index dim2 = output.dimension(1); \ 606e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index dim3 = output.dimension(2); \ 616e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index size = dim1 * dim2 * dim3; \ 626e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp if (size == 0) { \ 636e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp return; \ 646e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp } \ 656e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp T zero = beginning<T>(); \ 666e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ThreadPool* thread_pool = \ 676e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ctx->device()->tensorflow_cpu_worker_threads()->workers; \ 686e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp /* shard the work */ \ 696e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp auto work = [&](Index start, Index end) { \ 706e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp for (Index global = start; global < end; ++global) { \ 716e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp XYZ xyz = global_index_to_xyz(global, XYZ(dim1, dim2, dim3)); \ 726e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index x = xyz.x; \ 736e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index y = xyz.y; \ 746e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index z = xyz.z; \ 756e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp output(x, y, z) = zero; \ 766e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index slice_head = indices(y * indices_width); \ 776e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Index slice_end = std::min(indices(y * indices_width + 1), bound); \ 786e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp for (Index i = slice_head; i < slice_end; ++i) { \ 796e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp output(x, y, z) = reduceop(output(x, y, z), data(x, i, z)); \ 806e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp } \ 816e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp } \ 826e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp }; \ 836e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp /* Here assumes the number of average CPU cycles for each slice equals \ 846e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp * the average length of each slice */ \ 856e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp thread_pool->ParallelFor(size, std::max(bound / dim2, (Index)1), work); \ 866e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp } \ 876e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp }; 886e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 896e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpCALL_ALL_REDUCEOPS(CPUReduceSliceFunctorReduceop) 906e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef CPUReduceSliceFunctorReduceop 916e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 926e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define DEFINE_CPU_SUMPROD_SPECS_INDEX(T, Index) \ 936e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp template struct ReduceSliceFunctorSum<CPUDevice, T, Index>; \ 946e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp template struct ReduceSliceFunctorProd<CPUDevice, T, Index>; 956e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 966e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define DEFINE_CPU_MINMAX_SPECS_INDEX(T, Index) \ 976e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp template struct ReduceSliceFunctorMax<CPUDevice, T, Index>; \ 986e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp template struct ReduceSliceFunctorMin<CPUDevice, T, Index>; 996e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1006e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define DEFINE_CPU_SUMPROD_SPECS(T) \ 1016e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp DEFINE_CPU_SUMPROD_SPECS_INDEX(T, int32); \ 1026e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp DEFINE_CPU_SUMPROD_SPECS_INDEX(T, int64); 1036e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1046e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define DEFINE_CPU_MINMAX_SPECS(T) \ 1056e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp DEFINE_CPU_MINMAX_SPECS_INDEX(T, int32); \ 1066e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp DEFINE_CPU_MINMAX_SPECS_INDEX(T, int64); 1076e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1086e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpTF_CALL_NUMBER_TYPES(DEFINE_CPU_SUMPROD_SPECS) 1096e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpTF_CALL_REAL_NUMBER_TYPES(DEFINE_CPU_MINMAX_SPECS) 1106e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1116e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef DEFINE_CPU_SUMPROD_SPECS_INDEX 1126e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef DEFINE_CPU_MINMAX_SPECS_INDEX 1136e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef DEFINE_CPU_SUMPROD_SPECS 1146e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef DEFINE_CPU_MINMAX_SPECS 1156e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1166e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp} // namespace functor 1176e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1186e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harptemplate <typename Device, typename T, typename Index, 1196e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp template <typename Device2, typename T2, typename Index2> 1206e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp class Functor> 1216e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harpclass ReduceSliceKernel : public OpKernel { 1226e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp public: 1236e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp explicit ReduceSliceKernel(OpKernelConstruction* context) 1246e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp : OpKernel(context) {} 1256e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1266e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp void Compute(OpKernelContext* context) override { 1276e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp const Tensor& data = context->input(0); 1286e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp const Tensor& indices = context->input(1); 1296e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp const Tensor& _axis = context->input(2); 1306e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp int64 axis = _axis.scalar<int64>()(); 1316e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1326e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp int indices_width = 2; 1336e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp int out_axis_dim_size = indices.shape().dim_size(0); 1346e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp if (indices.dims() == 1 || indices.shape().dim_size(1) == 1) { 1356e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp indices_width = 1; 1366e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp if (out_axis_dim_size > 0) { 1376e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp out_axis_dim_size--; 1386e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp } 1396e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp } 1406e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1416e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp TensorShape output_shape = data.shape(); 1426e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp output_shape.set_dim(axis, out_axis_dim_size); 1436e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp Tensor* output = nullptr; 1446e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 1456e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp auto functor = Functor<Device, T, Index>(); 1466e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp functor(context, context->eigen_device<Device>(), indices_width, 1476e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp indices.flat<Index>(), data.flat_inner_outer_dims<T, 3>(axis - 1), 1486e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp output->flat_inner_outer_dims<T, 3>(axis - 1)); 1496e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp } 1506e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp}; 1516e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1526e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, index_type) \ 1536e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_KERNEL_BUILDER(Name("ReduceSliceSum") \ 1546e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .Device(DEVICE_CPU) \ 1556e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<type>("T") \ 1566e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<index_type>("Tindices"), \ 1576e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ReduceSliceKernel<CPUDevice, type, index_type, \ 1586e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp functor::ReduceSliceFunctorSum>); \ 1596e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_KERNEL_BUILDER(Name("ReduceSliceProd") \ 1606e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .Device(DEVICE_CPU) \ 1616e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<type>("T") \ 1626e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<index_type>("Tindices"), \ 1636e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ReduceSliceKernel<CPUDevice, type, index_type, \ 1646e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp functor::ReduceSliceFunctorProd>); 1656e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1666e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, index_type) \ 1676e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_KERNEL_BUILDER(Name("ReduceSliceMax") \ 1686e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .Device(DEVICE_CPU) \ 1696e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<type>("T") \ 1706e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<index_type>("Tindices"), \ 1716e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ReduceSliceKernel<CPUDevice, type, index_type, \ 1726e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp functor::ReduceSliceFunctorMax>); \ 1736e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_KERNEL_BUILDER(Name("ReduceSliceMin") \ 1746e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .Device(DEVICE_CPU) \ 1756e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<type>("T") \ 1766e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<index_type>("Tindices"), \ 1776e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ReduceSliceKernel<CPUDevice, type, index_type, \ 1786e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp functor::ReduceSliceFunctorMin>); 1796e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1806e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL(type) \ 1816e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, int32); \ 1826e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, int64); 1836e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1846e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL(type) \ 1856e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, int32); \ 1866e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, int64); 1876e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1886e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpTF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL) 1896e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpTF_CALL_NUMBER_TYPES(REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL) 1906e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1916e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS 1926e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS 1936e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL 1946e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL 1956e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1966e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#if GOOGLE_CUDA 1976e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 1986e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_GPU_REDUCE_SLICE_KERNELS(type, index_type) \ 1996e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_KERNEL_BUILDER(Name("ReduceSliceSum") \ 2006e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .Device(DEVICE_GPU) \ 2016e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .HostMemory("axis") \ 2026e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<type>("T") \ 2036e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<index_type>("Tindices"), \ 2046e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ReduceSliceKernel<GPUDevice, type, index_type, \ 2056e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp functor::ReduceSliceFunctorSum>); \ 2066e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_KERNEL_BUILDER(Name("ReduceSliceProd") \ 2076e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .Device(DEVICE_GPU) \ 2086e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .HostMemory("axis") \ 2096e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<type>("T") \ 2106e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<index_type>("Tindices"), \ 2116e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ReduceSliceKernel<GPUDevice, type, index_type, \ 2126e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp functor::ReduceSliceFunctorProd>); \ 2136e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_KERNEL_BUILDER(Name("ReduceSliceMax") \ 2146e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .Device(DEVICE_GPU) \ 2156e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .HostMemory("axis") \ 2166e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<type>("T") \ 2176e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<index_type>("Tindices"), \ 2186e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ReduceSliceKernel<GPUDevice, type, index_type, \ 2196e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp functor::ReduceSliceFunctorMax>); \ 2206e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_KERNEL_BUILDER(Name("ReduceSliceMin") \ 2216e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .Device(DEVICE_GPU) \ 2226e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .HostMemory("axis") \ 2236e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<type>("T") \ 2246e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp .TypeConstraint<index_type>("Tindices"), \ 2256e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp ReduceSliceKernel<GPUDevice, type, index_type, \ 2266e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp functor::ReduceSliceFunctorMin>); 2276e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 2286e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#define REGISTER_GPU_REDUCE_SLICE_KERNELS_ALL(type) \ 2296e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_GPU_REDUCE_SLICE_KERNELS(type, int32); \ 2306e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp REGISTER_GPU_REDUCE_SLICE_KERNELS(type, int64); 2316e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 2326e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew HarpTF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_REDUCE_SLICE_KERNELS_ALL); 2336e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 2346e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_GPU_REDUCE_SLICE_KERNELS 2356e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#undef REGISTER_GPU_REDUCE_SLICE_KERNELS_ALL 2366e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 2376e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp#endif // GOOGLE_CUDA 2386e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp 2396e3e7d18f42cb4237ce6dbe2ffd0f9f158c36dafAndrew Harp} // namespace tensorflow 240