11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// XLA specific pooling ops.
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h"
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h"
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
21a8c325e57c1077f1e8df540a20bd8b36d3d1f968Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/util.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/op_kernel.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/register_types.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/tensor.h"
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/bounds_check.h"
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/conv_grad_ops.h"
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/kernels/pooling_ops_common.h"
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow {
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Superclass of pooling ops.
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass PoolingOp : public XlaOpKernel {
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
38c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims)
39c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
40995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang    if (ctx->num_inputs() == 1) {
41995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      std::vector<int32> ksize_int;
42995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      std::vector<int32> stride_int;
43995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
44995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
45995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                  errors::InvalidArgument("Sliding window ksize field must "
46995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          "specify ",
47995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          num_dims(), " dimensions"));
48995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
49995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES(ctx, stride_int.size() == num_dims(),
50995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                  errors::InvalidArgument("Sliding window stride field must "
51995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          "specify ",
52995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          num_dims(), " dimensions"));
53995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      for (int i = 0; i < num_dims(); ++i) {
54995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang        ksize_.push_back(ksize_int[i]);
55995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang        stride_.push_back(stride_int[i]);
56995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      }
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    Padding padding;
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
63c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  int num_dims() const { return num_spatial_dims_ + 2; }
64c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Method that builds an initial value to use in reductions.
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b,
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                               DataType data_type) = 0;
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // The reduction operation to apply to each window.
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx,
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                            DataType dtype) = 0;
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // A post-processing operation to apply on the outputs of the ReduceWindow.
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  virtual xla::ComputationDataHandle PostProcessOutput(
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      DataType dtype, const TensorShape& input_shape) = 0;
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationDataHandle input = ctx->Input(0);
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape input_shape = ctx->InputShape(0);
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
82995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang    std::vector<int64> ksize = ksize_;
83995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang    std::vector<int64> stride = stride_;
84995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang    if (ctx->num_inputs() != 1) {
85995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      const TensorShape ksize_shape = ctx->InputShape(1);
86995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      // Validate input sizes.
87995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
88995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                  errors::InvalidArgument("ksize must be a vector, not shape ",
89995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          ksize_shape.DebugString()));
90995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
91995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                  errors::InvalidArgument("Sliding window ksize field must "
92995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          "specify ",
93995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          num_dims(), " dimensions"));
94995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      ksize.clear();
95995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));
96995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang
97995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      const TensorShape stride_shape = ctx->InputShape(2);
98995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      // Validate input sizes.
99995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
100995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                  errors::InvalidArgument("stride must be a vector, not shape ",
101995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          stride_shape.DebugString()));
102995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
103995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                  errors::InvalidArgument("Sliding window stride field must "
104995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          "specify ",
105995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                                          num_dims(), " dimensions"));
106995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      stride.clear();
107995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
108995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang    }
109c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
110c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                errors::InvalidArgument("Input to ", type_string(),
111c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        " operator must have ", num_dims(),
112c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        " dimensions"));
113c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const DataType type = input_type(0);
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow(
116995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang        input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize,
117995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang        stride, padding_);
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape));
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected:
122c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  const int num_spatial_dims_;
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> ksize_;
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> stride_;
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  xla::Padding padding_;
126c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  TensorFormat data_format_ = FORMAT_NHWC;
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass MaxPoolOp : public PoolingOp {
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
131c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
132c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {}
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b,
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                       DataType data_type) override {
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return XlaHelpers::MinValue(b, data_type);
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const xla::Computation* Reduction(XlaOpKernelContext* ctx,
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                    DataType dtype) override {
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return ctx->GetOrCreateMax(dtype);
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  xla::ComputationDataHandle PostProcessOutput(
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      DataType dtype, const TensorShape& input_shape) override {
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return output;
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
151c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass MaxPool2DOp : public MaxPoolOp {
152c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public:
153c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  explicit MaxPool2DOp(OpKernelConstruction* ctx)
154c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {
155c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    string data_format_str;
156c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
157c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
158c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                errors::InvalidArgument("Invalid data format"));
159c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  }
160c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins};
16193f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
162995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong TangREGISTER_XLA_OP(Name("MaxPoolV2")
163995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                    .CompileTimeConstInput("ksize")
164995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                    .CompileTimeConstInput("strides"),
165995378c4c9ff156cae7a365cfdc1480a3ee6d0bfYong Tang                MaxPool2DOp);
166c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins
167c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass MaxPool3DOp : public MaxPoolOp {
168c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public:
169c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  explicit MaxPool3DOp(OpKernelConstruction* ctx)
170c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
171c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins};
17293f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Common computation shared between AvgPool and AvgPoolGrad. Divide each
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// element of an image by the count of elements that contributed to that
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// element during pooling.
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsstatic xla::ComputationDataHandle AvgPoolDivideByCount(
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    DataType dtype, const TensorShape& input_shape, xla::Padding padding,
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const std::vector<int64>& ksize, const std::vector<int64>& stride,
181c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    int num_spatial_dims, TensorFormat data_format) {
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (padding == xla::Padding::kValid) {
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // In VALID padding, all windows have the same number of elements
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // contributing to each average. Divide by the window size everywhere to
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // get the average.
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1,
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        [](int64 a, int64 b) { return a * b; });
1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto divisor =
1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return ctx->builder()->Div(output, divisor);
1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  } else {
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // For SAME padding, the padding shouldn't be included in the
1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // counts. We use another ReduceWindow to find the right counts.
1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // TODO(phawkins): use a less brute-force way to compute this. Only
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // the boundary regions will have interesting values here.
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
199c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    std::vector<int64> input_dim_sizes(num_spatial_dims);
200c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    std::vector<int64> window_dims(num_spatial_dims);
201c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    std::vector<int64> window_ksize(num_spatial_dims);
202c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    std::vector<int64> window_stride(num_spatial_dims);
203c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    for (int i = 0; i < num_spatial_dims; ++i) {
204c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i);
205c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      input_dim_sizes[i] = input_shape.dim_size(dim);
206c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      window_dims[i] = dim;
207c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      window_ksize[i] = ksize[dim];
208c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      window_stride[i] = stride[dim];
209c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    }
2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Build a matrix of all 1s, with the same width/height as the input.
2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto ones = ctx->builder()->Broadcast(
213c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins        XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes);
2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Perform a ReduceWindow with the same window size, strides, and padding
2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // to count the number of contributions to each result element.
2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto counts = ctx->builder()->ReduceWindow(
2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ones, XlaHelpers::Zero(ctx->builder(), dtype),
219c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins        *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride,
220c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins        xla::Padding::kSame);
2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
222c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    return ctx->builder()->Div(output, counts, window_dims);
2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass AvgPoolOp : public PoolingOp {
2271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
228c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
229c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : PoolingOp(ctx, num_spatial_dims) {}
2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b,
2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                       DataType data_type) override {
2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return XlaHelpers::Zero(b, data_type);
2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const xla::Computation* Reduction(XlaOpKernelContext* ctx,
2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                    DataType dtype) override {
2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return ctx->GetOrCreateAdd(dtype);
2391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  xla::ComputationDataHandle PostProcessOutput(
2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      DataType dtype, const TensorShape& input_shape) override {
2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
245c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                ksize_, stride_, num_spatial_dims_,
246c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                data_format_);
2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
248c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins};
2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
250c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPool2DOp : public AvgPoolOp {
251c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public:
252c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  explicit AvgPool2DOp(OpKernelConstruction* ctx)
253c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {
254c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    string data_format_str;
255c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
256c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
257c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                errors::InvalidArgument("Invalid data format"));
258c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  }
2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
26093f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);
2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
262c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPool3DOp : public AvgPoolOp {
263c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public:
264c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  explicit AvgPool3DOp(OpKernelConstruction* ctx)
265c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {}
266c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins};
26793f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp);
2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// The operation to compute MaxPool gradients.
2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// It takes three inputs:
2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//   - The original input tensor
2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//   - The original output tensor
2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//   - Backprop tensor for output
2741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// It produces one output: backprop tensor for input.
2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass MaxPoolGradOp : public XlaOpKernel {
2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
277c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
278c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
2798fdef24247814d322991127091cc6d0c1eb60380Russell Power    if (ctx->num_inputs() == 3) {
2808fdef24247814d322991127091cc6d0c1eb60380Russell Power      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
2818fdef24247814d322991127091cc6d0c1eb60380Russell Power      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
2828fdef24247814d322991127091cc6d0c1eb60380Russell Power    }
2838fdef24247814d322991127091cc6d0c1eb60380Russell Power    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
2848fdef24247814d322991127091cc6d0c1eb60380Russell Power  }
2858fdef24247814d322991127091cc6d0c1eb60380Russell Power
2868fdef24247814d322991127091cc6d0c1eb60380Russell Power  int num_dims() const { return num_spatial_dims_ + 2; }
2878fdef24247814d322991127091cc6d0c1eb60380Russell Power
2888fdef24247814d322991127091cc6d0c1eb60380Russell Power  void Compile(XlaOpKernelContext* ctx) override {
2898fdef24247814d322991127091cc6d0c1eb60380Russell Power    if (ctx->num_inputs() != 3) {
2908fdef24247814d322991127091cc6d0c1eb60380Russell Power      OP_REQUIRES(
2918fdef24247814d322991127091cc6d0c1eb60380Russell Power          ctx, ctx->num_inputs() == 5,
2928fdef24247814d322991127091cc6d0c1eb60380Russell Power          errors::InvalidArgument("Must supply ksize and stride arguments."));
2938fdef24247814d322991127091cc6d0c1eb60380Russell Power      const TensorShape ksize_shape = ctx->InputShape(3);
2948fdef24247814d322991127091cc6d0c1eb60380Russell Power      // Validate input sizes.
2958fdef24247814d322991127091cc6d0c1eb60380Russell Power      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
2968fdef24247814d322991127091cc6d0c1eb60380Russell Power                  errors::InvalidArgument("ksize must be a vector, not shape ",
2978fdef24247814d322991127091cc6d0c1eb60380Russell Power                                          ksize_shape.DebugString()));
2988fdef24247814d322991127091cc6d0c1eb60380Russell Power      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));
2998fdef24247814d322991127091cc6d0c1eb60380Russell Power
3008fdef24247814d322991127091cc6d0c1eb60380Russell Power      const TensorShape stride_shape = ctx->InputShape(4);
3018fdef24247814d322991127091cc6d0c1eb60380Russell Power      // Validate input sizes.
3028fdef24247814d322991127091cc6d0c1eb60380Russell Power      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
3038fdef24247814d322991127091cc6d0c1eb60380Russell Power                  errors::InvalidArgument("stride must be a vector, not shape ",
3048fdef24247814d322991127091cc6d0c1eb60380Russell Power                                          stride_shape.DebugString()));
3058fdef24247814d322991127091cc6d0c1eb60380Russell Power      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
3068fdef24247814d322991127091cc6d0c1eb60380Russell Power    }
3078fdef24247814d322991127091cc6d0c1eb60380Russell Power
308c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Sliding window ksize field must "
310c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        "specify ",
311c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        num_dims(), " dimensions"));
312c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, stride_.size() == num_dims(),
3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Sliding window strides field must "
314c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        "specify ",
315c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        num_dims(), " dimensions"));
316c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins
3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape tensor_in_shape = ctx->InputShape(0);
3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape tensor_out_shape = ctx->InputShape(1);
3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape out_backprop_shape = ctx->InputShape(2);
3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
321c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    // For maxpooling, tensor_in should have num_dims() dimensions.
322c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
323c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                errors::InvalidArgument("tensor_in must be ", num_dims(),
324c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        "-dimensional"));
325c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
326c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                errors::InvalidArgument("tensor_out must be ", num_dims(),
327c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        "-dimensional"));
328c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    // For maxpooling, out_backprop should have num_dims() dimensions.
329c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
330c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                errors::InvalidArgument("out_backprop must be ", num_dims(),
331c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        "-dimensional"));
3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // whether this is a good time/space tradeoff.
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto input = ctx->Input(0);
3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto out_backprop = ctx->Input(2);
3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::Padding xla_padding =
3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::PrimitiveType element_type;
3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
3431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationDataHandle init_value =
3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        XlaHelpers::Zero(ctx->builder(), input_type(2));
3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto select = CreateScalarGeComputation(element_type, ctx->builder());
3461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter(
3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        input, select, ksize_, stride_, xla_padding, out_backprop, init_value,
3491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        scatter);
3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0, gradients);
3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
354c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins protected:
355c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  const int num_spatial_dims_;
3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> ksize_;
3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> stride_;
3581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Padding padding_;
359c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  TensorFormat data_format_ = FORMAT_NHWC;
3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
362c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass MaxPool2DGradOp : public MaxPoolGradOp {
3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
364c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
365c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
3661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    string data_format;
3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Invalid data format"));
370c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  }
371c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins};
37293f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
3738fdef24247814d322991127091cc6d0c1eb60380Russell PowerREGISTER_XLA_OP(Name("MaxPoolGradV2")
3748fdef24247814d322991127091cc6d0c1eb60380Russell Power                    .CompileTimeConstInput("ksize")
3758fdef24247814d322991127091cc6d0c1eb60380Russell Power                    .CompileTimeConstInput("strides"),
3768fdef24247814d322991127091cc6d0c1eb60380Russell Power                MaxPool2DGradOp);
377c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins
378c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass MaxPool3DGradOp : public MaxPoolGradOp {
379c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public:
380c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  explicit MaxPool3DGradOp(OpKernelConstruction* ctx)
381c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
382c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins};
38393f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp);
384c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins
385c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins// Average-pooling gradient
386c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPoolGradOp : public XlaOpKernel {
387c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public:
388c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
389c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
391c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Sliding window ksize field must "
393c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        "specify ",
394c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        num_dims(), " dimensions"));
3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
396c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, stride_.size() == num_dims(),
3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Sliding window strides field must "
398c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        "specify ",
399c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        num_dims(), " dimensions"));
4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::Unimplemented(
4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                    "Pooling is not yet supported on the batch dimension."));
4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
406c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  int num_dims() const { return num_spatial_dims_ + 2; }
407c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins
4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    TensorShape gradients_shape;
4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));
4111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape out_backprop_shape = ctx->InputShape(1);
4131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
414c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    // For avgpooling, tensor_in_shape should have num_dims() dimensions.
415c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
416c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                errors::InvalidArgument("orig_input_shape must be ", num_dims(),
417c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        "-dimensional"));
4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
419c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    // For avgpooling, out_backprop should have num_dims() dimensions.
420c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
421c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                errors::InvalidArgument("out_backprop must be ", num_dims(),
422c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                                        "-dimensional"));
4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
424c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
425c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    int64 depth = out_backprop_shape.dim_size(depth_dim);
4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // We can think of average-pooling as:
4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // * a convolution with a kernel consisting entirely of 1s, where the
4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //   input feature and output feature are equal, and 0s everywhere else.
4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // * followed by dividing by the counts.
4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //
4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // This then gives us an algorithm to build the gradient:
4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // * divide out_backprop by the counts, followed by
4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // * Conv2DBackpropInput specialized for that kernel, which simplifies to
4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //   a Pad and a ReduceWindow.
4361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //
4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // For an explanation of backpropagation for convolution, see the comments
4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // in third_party/tensorflow/core/kernels/conv_grad_ops.h
4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
440c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    // TF filter shape is [ H, W, ..., inC, outC ]
441c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    std::vector<int64> filter_dims(num_dims());
442c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    for (int i = 0; i < num_spatial_dims_; ++i) {
443c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
444c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      filter_dims[i] = ksize_[dim];
445c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    }
446c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    filter_dims[num_dims() - 2] = depth;
447c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    filter_dims[num_dims() - 1] = depth;
448c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    TensorShape filter_shape(filter_dims);
4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Reuse the logic from Conv2DBackpropInput to compute padding.
45119dd9342e7bc55c877367b7474caf41e819e38c3Peter Hawkins    ConvBackpropDimensions dims;
452c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES_OK(
453c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins        ctx, ConvBackpropComputeDimensions(
454c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                 type_string(), /*num_spatial_dims=*/num_spatial_dims_,
455c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                 gradients_shape, filter_shape, out_backprop_shape, stride_,
456c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                 padding_, data_format_, &dims));
4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto out_backprop = ctx->Input(1);
4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // The input gradients are computed by a convolution of the output
4611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // gradients
4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // and the filter, with some appropriate padding. See the comment at
4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // the top of conv_grad_ops.h for details.
4641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    DataType dtype = input_type(1);
4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::Padding xla_padding =
4671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
4681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Divide the out_backprop values by the counts for each spatial position.
4701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
471c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    auto out_backprop_div = AvgPoolDivideByCount(
472c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins        ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_,
473c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins        stride_int64s, num_spatial_dims_, data_format_);
4741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Pad the gradients in the spatial dimensions. We use the same padding
4761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // as Conv2DBackpropInput.
477c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims());
478c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    for (int i = 0; i < num_spatial_dims_; ++i) {
479c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
480c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      auto* padding = padding_config.mutable_dimensions(dim);
481c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      padding->set_edge_padding_low(dims.spatial_dims[i].pad_before);
482c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      padding->set_edge_padding_high(dims.spatial_dims[i].pad_after);
483c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      padding->set_interior_padding(dims.spatial_dims[i].stride - 1);
484c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    }
4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto zero = XlaHelpers::Zero(ctx->builder(), dtype);
4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto padded_gradients =
4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ctx->builder()->Pad(out_backprop_div, zero, padding_config);
4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // in_backprop = padded_gradients <conv> ones
491c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    std::vector<int64> ones(num_dims(), 1LL);
4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow(
4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_,
494c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins        /* window_strides=*/ones, xla::Padding::kValid);
4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0, in_backprop);
4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
499c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins protected:
500c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  const int num_spatial_dims_;
5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int64> ksize_;
5021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<int32> stride_;
5031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Padding padding_;
504c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  TensorFormat data_format_ = FORMAT_NHWC;
5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
507c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPool2DGradOp : public AvgPoolGradOp {
508c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public:
509c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
510c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {
511c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    string data_format;
512c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
513c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
514c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins                errors::InvalidArgument("Invalid data format"));
515c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  }
516c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins};
517c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"),
518c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                AvgPool2DGradOp);
519c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins
520c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkinsclass AvgPool3DGradOp : public AvgPoolGradOp {
521c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins public:
522c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins  explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
523c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
524c5238fb028f728def4d6ece44e06e8003a8defbbPeter Hawkins};
525c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"),
526c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                AvgPool3DGradOp);
5271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // anonymous namespace
5291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace tensorflow
530