/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA specific pooling ops.
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h" 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h" 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 21a8c325e57c1077f1e8df540a20bd8b36d3d1f968Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/lib/arithmetic.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/util.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/op_kernel.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/register_types.h" 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/tensor.h" 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/bounds_check.h" 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/conv_grad_ops.h" 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/pooling_ops_common.h" 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow { 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Superclass of pooling ops. 
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass PoolingOp : public XlaOpKernel { 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 38c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims) 39c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { 40995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang if (ctx->num_inputs() == 1) { 41995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang std::vector<int32> ksize_int; 42995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang std::vector<int32> stride_int; 43995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); 44995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES(ctx, ksize_int.size() == num_dims(), 45995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang errors::InvalidArgument("Sliding window ksize field must " 46995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang "specify ", 47995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang num_dims(), " dimensions")); 48995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int)); 49995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES(ctx, stride_int.size() == num_dims(), 50995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang errors::InvalidArgument("Sliding window stride field must " 51995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang "specify ", 52995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang num_dims(), " dimensions")); 53995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang for (int i = 0; i < num_dims(); ++i) { 54995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang ksize_.push_back(ksize_int[i]); 55995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang stride_.push_back(stride_int[i]); 56995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang } 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 
Padding padding; 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame; 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 63c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins int num_dims() const { return num_spatial_dims_ + 2; } 64c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Method that builds an initial value to use in reductions. 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType data_type) = 0; 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // The reduction operation to apply to each window. 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx, 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType dtype) = 0; 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // A post-processing operation to apply on the outputs of the ReduceWindow. 
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins virtual xla::ComputationDataHandle PostProcessOutput( 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType dtype, const TensorShape& input_shape) = 0; 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins void Compile(XlaOpKernelContext* ctx) override { 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::ComputationDataHandle input = ctx->Input(0); 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const TensorShape input_shape = ctx->InputShape(0); 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 82995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang std::vector<int64> ksize = ksize_; 83995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang std::vector<int64> stride = stride_; 84995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang if (ctx->num_inputs() != 1) { 85995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang const TensorShape ksize_shape = ctx->InputShape(1); 86995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang // Validate input sizes. 
87995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), 88995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang errors::InvalidArgument("ksize must be a vector, not shape ", 89995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang ksize_shape.DebugString())); 90995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), 91995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang errors::InvalidArgument("Sliding window ksize field must " 92995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang "specify ", 93995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang num_dims(), " dimensions")); 94995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang ksize.clear(); 95995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); 96995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang 97995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang const TensorShape stride_shape = ctx->InputShape(2); 98995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang // Validate input sizes. 
99995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), 100995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang errors::InvalidArgument("stride must be a vector, not shape ", 101995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang stride_shape.DebugString())); 102995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), 103995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang errors::InvalidArgument("Sliding window stride field must " 104995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang "specify ", 105995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang num_dims(), " dimensions")); 106995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang stride.clear(); 107995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); 108995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang } 109c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, input_shape.dims() == num_dims(), 110c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins errors::InvalidArgument("Input to ", type_string(), 111c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins " operator must have ", num_dims(), 112c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins " dimensions")); 113c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins 1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const DataType type = input_type(0); 1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( 116995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize, 117995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang stride, padding_); 1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape)); 1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected: 122c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins const int num_spatial_dims_; 1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> ksize_; 1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> stride_; 1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::Padding padding_; 126c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins TensorFormat data_format_ = FORMAT_NHWC; 1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass MaxPoolOp : public PoolingOp { 1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 131c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) 132c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {} 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, 1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType data_type) override { 1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return XlaHelpers::MinValue(b, data_type); 1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const xla::Computation* Reduction(XlaOpKernelContext* ctx, 1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType dtype) override { 1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ctx->GetOrCreateMax(dtype); 1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::ComputationDataHandle PostProcessOutput( 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType dtype, const TensorShape& input_shape) override { 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return output; 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 151c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass MaxPool2DOp : public MaxPoolOp { 152c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public: 153c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins explicit MaxPool2DOp(OpKernelConstruction* ctx) 154c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { 155c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins string data_format_str; 156c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); 157c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), 158c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins errors::InvalidArgument("Invalid data format")); 159c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins } 160c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins}; 16193f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); 162995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong TangREGISTER_XLA_OP(Name("MaxPoolV2") 163995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang .CompileTimeConstInput("ksize") 164995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang .CompileTimeConstInput("strides"), 165995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang MaxPool2DOp); 166c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins 
167c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass MaxPool3DOp : public MaxPoolOp { 168c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public: 169c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins explicit MaxPool3DOp(OpKernelConstruction* ctx) 170c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {} 171c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins}; 17293f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); 1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Common computation shared between AvgPool and AvgPoolGrad. Divide each 1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// element of an image by the count of elements that contributed to that 1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// element during pooling. 1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsstatic xla::ComputationDataHandle AvgPoolDivideByCount( 1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType dtype, const TensorShape& input_shape, xla::Padding padding, 1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::vector<int64>& ksize, const std::vector<int64>& stride, 181c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins int num_spatial_dims, TensorFormat data_format) { 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (padding == xla::Padding::kValid) { 1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // In VALID padding, all windows have the same number of elements 1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // contributing to each average. Divide by the window size everywhere to 1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // get the average. 
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1, 1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins [](int64 a, int64 b) { return a * b; }); 1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto divisor = 1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); 1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ctx->builder()->Div(output, divisor); 1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } else { 1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // For SAME padding, the padding shouldn't be included in the 1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // counts. We use another ReduceWindow to find the right counts. 1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // TODO(phawkins): use a less brute-force way to compute this. Only 1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // the boundary regions will have interesting values here. 
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 199c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins std::vector<int64> input_dim_sizes(num_spatial_dims); 200c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins std::vector<int64> window_dims(num_spatial_dims); 201c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins std::vector<int64> window_ksize(num_spatial_dims); 202c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins std::vector<int64> window_stride(num_spatial_dims); 203c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins for (int i = 0; i < num_spatial_dims; ++i) { 204c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i); 205c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins input_dim_sizes[i] = input_shape.dim_size(dim); 206c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins window_dims[i] = dim; 207c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins window_ksize[i] = ksize[dim]; 208c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins window_stride[i] = stride[dim]; 209c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins } 2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Build a matrix of all 1s, with the same width/height as the input. 2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto ones = ctx->builder()->Broadcast( 213c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes); 2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Perform a ReduceWindow with the same window size, strides, and padding 2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // to count the number of contributions to each result element. 
2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto counts = ctx->builder()->ReduceWindow( 2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ones, XlaHelpers::Zero(ctx->builder(), dtype), 219c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride, 220c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins xla::Padding::kSame); 2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 222c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins return ctx->builder()->Div(output, counts, window_dims); 2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass AvgPoolOp : public PoolingOp { 2271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 228c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) 229c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : PoolingOp(ctx, num_spatial_dims) {} 2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, 2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType data_type) override { 2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return XlaHelpers::Zero(b, data_type); 2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const xla::Computation* Reduction(XlaOpKernelContext* ctx, 2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType dtype) override { 2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return ctx->GetOrCreateAdd(dtype); 2391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 
2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::ComputationDataHandle PostProcessOutput( 2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, 2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType dtype, const TensorShape& input_shape) override { 2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, 245c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins ksize_, stride_, num_spatial_dims_, 246c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins data_format_); 2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 248c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins}; 2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 250c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPool2DOp : public AvgPoolOp { 251c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public: 252c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins explicit AvgPool2DOp(OpKernelConstruction* ctx) 253c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { 254c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins string data_format_str; 255c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); 256c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), 257c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins errors::InvalidArgument("Invalid data format")); 258c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins } 2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 26093f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); 2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 262c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPool3DOp : public AvgPoolOp { 
263c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public: 264c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins explicit AvgPool3DOp(OpKernelConstruction* ctx) 265c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {} 266c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins}; 26793f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp); 2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// The operation to compute MaxPool gradients. 2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// It takes three inputs: 2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// - The original input tensor 2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// - The original output tensor 2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// - Backprop tensor for output 2741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// It produces one output: backprop tensor for input. 
2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass MaxPoolGradOp : public XlaOpKernel { 2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 277c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims) 278c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { 2798fdef24247814d322991127091cc6d0c1eb60380Russell Power if (ctx->num_inputs() == 3) { 2808fdef24247814d322991127091cc6d0c1eb60380Russell Power OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); 2818fdef24247814d322991127091cc6d0c1eb60380Russell Power OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); 2828fdef24247814d322991127091cc6d0c1eb60380Russell Power } 2838fdef24247814d322991127091cc6d0c1eb60380Russell Power OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); 2848fdef24247814d322991127091cc6d0c1eb60380Russell Power } 2858fdef24247814d322991127091cc6d0c1eb60380Russell Power 2868fdef24247814d322991127091cc6d0c1eb60380Russell Power int num_dims() const { return num_spatial_dims_ + 2; } 2878fdef24247814d322991127091cc6d0c1eb60380Russell Power 2888fdef24247814d322991127091cc6d0c1eb60380Russell Power void Compile(XlaOpKernelContext* ctx) override { 2898fdef24247814d322991127091cc6d0c1eb60380Russell Power if (ctx->num_inputs() != 3) { 2908fdef24247814d322991127091cc6d0c1eb60380Russell Power OP_REQUIRES( 2918fdef24247814d322991127091cc6d0c1eb60380Russell Power ctx, ctx->num_inputs() == 5, 2928fdef24247814d322991127091cc6d0c1eb60380Russell Power errors::InvalidArgument("Must supply ksize and stride arguments.")); 2938fdef24247814d322991127091cc6d0c1eb60380Russell Power const TensorShape ksize_shape = ctx->InputShape(3); 2948fdef24247814d322991127091cc6d0c1eb60380Russell Power // Validate input sizes. 
2958fdef24247814d322991127091cc6d0c1eb60380Russell Power OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), 2968fdef24247814d322991127091cc6d0c1eb60380Russell Power errors::InvalidArgument("ksize must be a vector, not shape ", 2978fdef24247814d322991127091cc6d0c1eb60380Russell Power ksize_shape.DebugString())); 2988fdef24247814d322991127091cc6d0c1eb60380Russell Power OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_)); 2998fdef24247814d322991127091cc6d0c1eb60380Russell Power 3008fdef24247814d322991127091cc6d0c1eb60380Russell Power const TensorShape stride_shape = ctx->InputShape(4); 3018fdef24247814d322991127091cc6d0c1eb60380Russell Power // Validate input sizes. 3028fdef24247814d322991127091cc6d0c1eb60380Russell Power OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), 3038fdef24247814d322991127091cc6d0c1eb60380Russell Power errors::InvalidArgument("stride must be a vector, not shape ", 3048fdef24247814d322991127091cc6d0c1eb60380Russell Power stride_shape.DebugString())); 3058fdef24247814d322991127091cc6d0c1eb60380Russell Power OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_)); 3068fdef24247814d322991127091cc6d0c1eb60380Russell Power } 3078fdef24247814d322991127091cc6d0c1eb60380Russell Power 308c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, ksize_.size() == num_dims(), 3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument("Sliding window ksize field must " 310c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins "specify ", 311c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins num_dims(), " dimensions")); 312c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, stride_.size() == num_dims(), 3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument("Sliding window strides field must " 314c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins "specify ", 315c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins num_dims(), " 
dimensions")); 316c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins 3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const TensorShape tensor_in_shape = ctx->InputShape(0); 3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const TensorShape tensor_out_shape = ctx->InputShape(1); 3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const TensorShape out_backprop_shape = ctx->InputShape(2); 3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 321c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins // For maxpooling, tensor_in should have num_dims() dimensions. 322c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(), 323c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins errors::InvalidArgument("tensor_in must be ", num_dims(), 324c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins "-dimensional")); 325c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(), 326c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins errors::InvalidArgument("tensor_out must be ", num_dims(), 327c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins "-dimensional")); 328c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins // For maxpooling, out_backprop should have num_dims() dimensions. 329c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), 330c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins errors::InvalidArgument("out_backprop must be ", num_dims(), 331c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins "-dimensional")); 3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate 3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // whether this is a good time/space tradeoff. 
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input = ctx->Input(0); 3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto out_backprop = ctx->Input(2); 3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::Padding xla_padding = 3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; 3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::PrimitiveType element_type; 3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); 3431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::ComputationDataHandle init_value = 3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins XlaHelpers::Zero(ctx->builder(), input_type(2)); 3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto select = CreateScalarGeComputation(element_type, ctx->builder()); 3461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); 3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter( 3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins input, select, ksize_, stride_, xla_padding, out_backprop, init_value, 3491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins scatter); 3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->SetOutput(0, gradients); 3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 354c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins protected: 355c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins const int num_spatial_dims_; 3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> 
ksize_; 3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> stride_; 3581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding_; 359c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins TensorFormat data_format_ = FORMAT_NHWC; 3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 362c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass MaxPool2DGradOp : public MaxPoolGradOp { 3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 364c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins explicit MaxPool2DGradOp(OpKernelConstruction* ctx) 365c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) { 3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins string data_format; 3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); 3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), 3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument("Invalid data format")); 370c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins } 371c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins}; 37293f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp); 3738fdef24247814d322991127091cc6d0c1eb60380Russell PowerREGISTER_XLA_OP(Name("MaxPoolGradV2") 3748fdef24247814d322991127091cc6d0c1eb60380Russell Power .CompileTimeConstInput("ksize") 3758fdef24247814d322991127091cc6d0c1eb60380Russell Power .CompileTimeConstInput("strides"), 3768fdef24247814d322991127091cc6d0c1eb60380Russell Power MaxPool2DGradOp); 377c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins 378c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass MaxPool3DGradOp : public MaxPoolGradOp { 379c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public: 
380c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins explicit MaxPool3DGradOp(OpKernelConstruction* ctx) 381c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {} 382c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins}; 38393f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp); 384c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins 385c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins// Average-pooling gradient 386c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPoolGradOp : public XlaOpKernel { 387c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public: 388c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims) 389c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { 3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); 391c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, ksize_.size() == num_dims(), 3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument("Sliding window ksize field must " 393c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins "specify ", 394c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins num_dims(), " dimensions")); 3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); 396c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, stride_.size() == num_dims(), 3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument("Sliding window strides field must " 398c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins "specify ", 399c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins num_dims(), " dimensions")); 4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK(ctx, 
ctx->GetAttr("padding", &padding_)); 4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, 4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::Unimplemented( 4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "Pooling is not yet supported on the batch dimension.")); 4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 406c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins int num_dims() const { return num_spatial_dims_ + 2; } 407c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins 4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins void Compile(XlaOpKernelContext* ctx) override { 4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TensorShape gradients_shape; 4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape)); 4111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const TensorShape out_backprop_shape = ctx->InputShape(1); 4131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 414c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins // For avgpooling, tensor_in_shape should have num_dims() dimensions. 415c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(), 416c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins errors::InvalidArgument("orig_input_shape must be ", num_dims(), 417c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins "-dimensional")); 4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 419c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins // For avgpooling, out_backprop should have num_dims() dimensions. 
420c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), 421c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins errors::InvalidArgument("out_backprop must be ", num_dims(), 422c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins "-dimensional")); 4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 424c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); 425c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins int64 depth = out_backprop_shape.dim_size(depth_dim); 4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // We can think of average-pooling as: 4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // * a convolution with a kernel consisting entirely of 1s, where the 4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // input feature and output feature are equal, and 0s everywhere else. 4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // * followed by dividing by the counts. 4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // 4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // This then gives us an algorithm to build the gradient: 4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // * divide out_backprop by the counts, followed by 4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // * Conv2DBackpropInput specialized for that kernel, which simplifies to 4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // a Pad and a ReduceWindow. 
4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // 4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // For an explanation of backpropagation for convolution, see the comments 4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // in third_party/tensorflow/core/kernels/conv_grad_ops.h 4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 440c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins // TF filter shape is [ H, W, ..., inC, outC ] 441c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins std::vector<int64> filter_dims(num_dims()); 442c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins for (int i = 0; i < num_spatial_dims_; ++i) { 443c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 444c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins filter_dims[i] = ksize_[dim]; 445c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins } 446c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins filter_dims[num_dims() - 2] = depth; 447c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins filter_dims[num_dims() - 1] = depth; 448c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins TensorShape filter_shape(filter_dims); 4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Reuse the logic from Conv2DBackpropInput to compute padding. 
45119dd9342e7bc55c877367b7474caf41e819e38c3Peter Hawkins ConvBackpropDimensions dims; 452c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES_OK( 453c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins ctx, ConvBackpropComputeDimensions( 454c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins type_string(), /*num_spatial_dims=*/num_spatial_dims_, 455c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins gradients_shape, filter_shape, out_backprop_shape, stride_, 456c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins padding_, data_format_, &dims)); 4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto out_backprop = ctx->Input(1); 4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // The input gradients are computed by a convolution of the output 4611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // gradients 4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // and the filter, with some appropriate padding. See the comment at 4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // the top of conv_grad_ops.h for details. 4641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins DataType dtype = input_type(1); 4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::Padding xla_padding = 4671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; 4681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Divide the out_backprop values by the counts for each spatial position. 
4701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> stride_int64s(stride_.begin(), stride_.end()); 471c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins auto out_backprop_div = AvgPoolDivideByCount( 472c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_, 473c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins stride_int64s, num_spatial_dims_, data_format_); 4741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Pad the gradients in the spatial dimensions. We use the same padding 4761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // as Conv2DBackpropInput. 477c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims()); 478c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins for (int i = 0; i < num_spatial_dims_; ++i) { 479c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 480c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins auto* padding = padding_config.mutable_dimensions(dim); 481c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins padding->set_edge_padding_low(dims.spatial_dims[i].pad_before); 482c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins padding->set_edge_padding_high(dims.spatial_dims[i].pad_after); 483c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins padding->set_interior_padding(dims.spatial_dims[i].stride - 1); 484c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins } 4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto zero = XlaHelpers::Zero(ctx->builder(), dtype); 4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto padded_gradients = 4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->builder()->Pad(out_backprop_div, zero, padding_config); 
4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // in_backprop = padded_gradients <conv> ones 491c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins std::vector<int64> ones(num_dims(), 1LL); 4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow( 4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_, 494c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins /* window_strides=*/ones, xla::Padding::kValid); 4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->SetOutput(0, in_backprop); 4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 499c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins protected: 500c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins const int num_spatial_dims_; 5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int64> ksize_; 5021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<int32> stride_; 5031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Padding padding_; 504c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins TensorFormat data_format_ = FORMAT_NHWC; 5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 507c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPool2DGradOp : public AvgPoolGradOp { 508c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public: 509c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins explicit AvgPool2DGradOp(OpKernelConstruction* ctx) 510c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { 511c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins string data_format; 512c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins 
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); 513c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), 514c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins errors::InvalidArgument("Invalid data format")); 515c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins } 516c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins}; 517c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"), 518c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins AvgPool2DGradOp); 519c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins 520c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPool3DGradOp : public AvgPoolGradOp { 521c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public: 522c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins explicit AvgPool3DGradOp(OpKernelConstruction* ctx) 523c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {} 524c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins}; 525c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), 526c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins AvgPool3DGradOp); 5271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // anonymous namespace 5291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace tensorflow 530