maxpooling_op.cc revision 28ce1d163eeffe618a6972c5245be0e660d94e85
1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/nn_ops.cc.
17
18#define EIGEN_USE_THREADS
19
20#include "tensorflow/core/kernels/maxpooling_op.h"
21
22#include <vector>
23#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24#include "tensorflow/core/common_runtime/device.h"
25#include "tensorflow/core/framework/numeric_op.h"
26#include "tensorflow/core/framework/op_kernel.h"
27#include "tensorflow/core/framework/register_types.h"
28#include "tensorflow/core/framework/tensor.h"
29#include "tensorflow/core/framework/tensor_shape.h"
30#include "tensorflow/core/framework/tensor_slice.h"
31#include "tensorflow/core/kernels/conv_2d.h"
32#include "tensorflow/core/kernels/eigen_pooling.h"
33#include "tensorflow/core/kernels/ops_util.h"
34#include "tensorflow/core/kernels/pooling_ops_common.h"
35#include "tensorflow/core/lib/core/errors.h"
36#include "tensorflow/core/lib/gtl/array_slice.h"
37#include "tensorflow/core/util/padding.h"
38#include "tensorflow/core/util/tensor_format.h"
39#include "tensorflow/core/util/use_cudnn.h"
40
41#if GOOGLE_CUDA
42#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
43#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
44#include "tensorflow/core/platform/stream_executor.h"
45#endif  // GOOGLE_CUDA
46
47namespace tensorflow {
48
49typedef Eigen::ThreadPoolDevice CPUDevice;
50typedef Eigen::GpuDevice GPUDevice;
51
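// Sentinel stored in the argmax tensor before any input element has been
// considered for a given output position; see SpatialMaxPoolWithArgMaxHelper.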
52const int kInvalidMaxPoolingIndex = -1;
53
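// Computes the max-pool output and per-element argmax for `tensor_in`, then
// uses that argmax to route `out_backprop` into `input_backprop`. All tensors
// are assumed to be NHWC, matching the PoolParameters built by the callers
// below.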
54template <typename Device, typename T>
55static void SpatialMaxPoolWithArgMaxHelper(
56    OpKernelContext* context, Tensor* output, Tensor* output_arg_max,
57    Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop,
58    const PoolParameters& params, const Padding& padding) {
59  typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
60      ConstEigenMatrixMap;
61  typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
62      EigenMatrixMap;
63  typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
64      EigenIndexMatrixMap;
65
66  ConstEigenMatrixMap in_mat(
67      tensor_in.flat<T>().data(), params.depth,
68      params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
69  EigenMatrixMap out_mat(
70      output->flat<T>().data(), params.depth,
71      params.out_width * params.out_height * params.tensor_in_batch);
72  EigenIndexMatrixMap out_arg_max_mat(
73      output_arg_max->flat<int64>().data(), params.depth,
74      params.out_width * params.out_height * params.tensor_in_batch);
75
76  const DeviceBase::CpuWorkerThreads& worker_threads =
77      *(context->device()->tensorflow_cpu_worker_threads());
78
79  // The following code does the following:
80  // 1. Flattens the input and output tensors into two dimensional arrays.
81  //    tensor_in_as_matrix:
82  //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
83  //    output_as_matrix:
84  //      depth by (out_width * out_height * tensor_in_batch)
85  //
86  // 2. Walks through the set of columns in the flattened tensor_in_as_matrix,
87  //    and updates the corresponding column(s) in output_as_matrix with the
88  //    max value.
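  //
  // In effect, the argmax value stored for an output element is the flat
  // NHWC offset of the winning input element, i.e.
  //   ((b * in_rows + h) * in_cols + w) * depth + d,
  // which is the `in_index * depth + d` computed in the loop below.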
89  auto shard = [&params, &in_mat, &out_mat, &out_arg_max_mat, &input_backprop,
90                &output_arg_max, &out_backprop](int64 start, int64 limit) {
91
92    const int32 depth = params.depth;
93    const int32 in_rows = params.tensor_in_rows;
94    const int32 in_cols = params.tensor_in_cols;
95    const int32 pad_rows = params.pad_rows;
96    const int32 pad_cols = params.pad_cols;
97    const int32 window_rows = params.window_rows;
98    const int32 window_cols = params.window_cols;
99    const int32 row_stride = params.row_stride;
100    const int32 col_stride = params.col_stride;
101    const int32 out_height = params.out_height;
102    const int32 out_width = params.out_width;
103
104    {
105      // Initializes the output tensor with the lowest value of T.
106      const int32 output_image_size = out_height * out_width * depth;
107      EigenMatrixMap out_shard(out_mat.data() + start * output_image_size, 1,
108                               (limit - start) * output_image_size);
109      out_shard.setConstant(Eigen::NumTraits<T>::lowest());
110      EigenIndexMatrixMap out_arg_max_shard(
111          out_arg_max_mat.data() + start * output_image_size, 1,
112          (limit - start) * output_image_size);
113      out_arg_max_shard.setConstant(kInvalidMaxPoolingIndex);
114    }
115
116    for (int64 b = start; b < limit; ++b) {
117      for (int h = 0; h < in_rows; ++h) {
118        for (int w = 0; w < in_cols; ++w) {
119          // (h_start, h_end) * (w_start, w_end) is the range of output
120          // locations that this input element contributes to.
121          const int hpad = h + pad_rows;
122          const int wpad = w + pad_cols;
123          const int h_start =
124              (hpad < window_rows) ? 0 : (hpad - window_rows) / row_stride + 1;
125          const int h_end = std::min(hpad / row_stride + 1, out_height);
126          const int w_start =
127              (wpad < window_cols) ? 0 : (wpad - window_cols) / col_stride + 1;
128          const int w_end = std::min(wpad / col_stride + 1, out_width);
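          // For illustration (values chosen arbitrarily, assuming
          // out_height >= 3): with window_rows = 3, row_stride = 2 and
          // pad_rows = 0, input row h = 5 gives h_start = (5 - 3) / 2 + 1 = 2
          // and h_end = min(5 / 2 + 1, out_height) = 3, so only output row 2
          // (covering input rows 4..6) receives this element.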
129          // compute elementwise max
130          const int64 in_index = (b * in_rows + h) * in_cols + w;
131          for (int ph = h_start; ph < h_end; ++ph) {
132            const int64 out_index_base = (b * out_height + ph) * out_width;
133            for (int pw = w_start; pw < w_end; ++pw) {
134              const int64 out_index = out_index_base + pw;
135              // NOTE(zhengxq): not using the Eigen matrix operation
136              // for now.
137              for (int d = 0; d < depth; ++d) {
138                const T& input_ref = in_mat.coeffRef(d, in_index);
139                T& output_ref = out_mat.coeffRef(d, out_index);
140                int64& out_arg_max_ref = out_arg_max_mat.coeffRef(d, out_index);
141                if (output_ref < input_ref ||
142                    out_arg_max_ref == kInvalidMaxPoolingIndex) {
143                  output_ref = input_ref;
144                  int64 input_offset = in_index * depth + d;
145                  out_arg_max_ref = input_offset;
146                }
147              }
148            }
149          }
150        }
151      }
152    }
153
154    {
155      auto input_backprop_flat = input_backprop->flat<T>();
156      auto out_arg_max_flat = output_arg_max->flat<int64>();
157      auto out_backprop_flat = out_backprop.flat<T>();
158
159      // Initialize output to 0.
160      const int64 in_size = in_rows * in_cols * depth;
161      const int64 in_start = start * in_size;
162      const int64 in_end = limit * in_size;
163      EigenMatrixMap in_shard(input_backprop_flat.data() + in_start, 1,
164                              in_end - in_start);
165      in_shard.setConstant(T(0));
166
167      // Backpropagate.
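      // In effect this is a scatter-add: each out_backprop element is added
      // to the input_backprop position recorded for it in the argmax tensor,
      // roughly `input_backprop[argmax[i]] += out_backprop[i]` for every
      // pooled element i.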
168      const int64 out_size = out_height * out_width * depth;
169      const int64 out_start = start * out_size;
170      const int64 out_end = limit * out_size;
171      for (int64 index = out_start; index < out_end; ++index) {
172        const int64 input_backprop_index = out_arg_max_flat(index);
173        // Although this check is in the inner loop, it is worth the cost:
174        // it guards against memory corruption from invalid argmax indices.
175        // Our benchmarks show that the performance impact is quite small.
176        CHECK(input_backprop_index >= in_start && input_backprop_index < in_end)
177            << "Invalid input backprop index: " << input_backprop_index << ", "
178            << in_start << ", " << in_end;
179        input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
180      }
181    }
182
183  };
184
185  const int64 shard_cost = params.tensor_in_rows * params.tensor_in_cols *
186                           params.depth * params.window_rows *
187                           params.window_cols;
188  Shard(worker_threads.num_threads, worker_threads.workers,
189        params.tensor_in_batch, shard_cost, shard);
190}
191
192// The operation to compute MaxPool gradients.
193// It takes three inputs:
194//   - The original input tensor
195//   - The original output tensor
196//   - Backprop tensor for output
197// It produces one output: backprop tensor for input.
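//
// Conceptually, the gradient for each pooling window flows entirely to the
// input element that attained the maximum; on CPU this routing is computed by
// SpatialMaxPoolWithArgMaxHelper above via its argmax tensor.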
198template <class Device, class T>
199class MaxPoolingGradOp : public OpKernel {
200 public:
201  explicit MaxPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
202    string data_format;
203    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
204    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
205                errors::InvalidArgument("Invalid data format"));
206    OP_REQUIRES(
207        context, data_format_ == FORMAT_NHWC,
208        errors::InvalidArgument("Default MaxPoolingGradOp only supports NHWC ",
209                                "on device type ",
210                                DeviceTypeString(context->device_type())));
211
212    if (context->num_inputs() == 3) {
213      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
214      OP_REQUIRES(context, ksize_.size() == 4,
215                  errors::InvalidArgument("Sliding window ksize field must "
216                                          "specify 4 dimensions"));
217      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
218      OP_REQUIRES(context, stride_.size() == 4,
219                  errors::InvalidArgument("Sliding window strides field must "
220                                          "specify 4 dimensions"));
221      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
222                  errors::Unimplemented(
223                      "Pooling is not yet supported on the batch dimension."));
224      OP_REQUIRES(
225          context, ksize_[3] == 1 && stride_[3] == 1,
226          errors::Unimplemented(
227              "MaxPoolingGrad is not yet supported on the depth dimension."));
228    }
229
230    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
231  }
232
233  void Compute(OpKernelContext* context) override {
234    const Tensor& tensor_in = context->input(0);
235    const Tensor& tensor_out = context->input(1);
236    const Tensor& out_backprop = context->input(2);
237
238    // For maxpooling, tensor_in should have 4 dimensions.
239    OP_REQUIRES(context, tensor_in.dims() == 4,
240                errors::InvalidArgument("tensor_in must be 4-dimensional"));
241    OP_REQUIRES(context, tensor_out.dims() == 4,
242                errors::InvalidArgument("tensor_out must be 4-dimensional"));
243    // For maxpooling, out_backprop should have 4 dimensions.
244    OP_REQUIRES(context, out_backprop.dims() == 4,
245                errors::InvalidArgument("out_backprop must be 4-dimensional"));
246
247    const TensorShape& output_shape = tensor_in.shape();
248
249    Tensor tensor_out_dup;
250    OP_REQUIRES_OK(context, context->forward_input_or_allocate_temp(
251                                {1}, DataTypeToEnum<T>::v(), tensor_out.shape(),
252                                &tensor_out_dup));
253    Tensor tensor_out_arg_max;
254    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(),
255                                                   tensor_out.shape(),
256                                                   &tensor_out_arg_max));
257    std::vector<int32> ksize = ksize_;
258    std::vector<int32> stride = stride_;
259    if (context->num_inputs() == 5) {
260      const Tensor& tensor_ksize = context->input(3);
261      auto value_ksize = tensor_ksize.flat<int32>();
262      ksize.resize(tensor_ksize.shape().num_elements());
263      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
264
265      const Tensor& tensor_stride = context->input(4);
266      auto value_stride = tensor_stride.flat<int32>();
267      stride.resize(tensor_stride.shape().num_elements());
268      std::copy_n(&value_stride(0), stride.size(), stride.begin());
269    }
270
271    OP_REQUIRES(context, ksize.size() == 4,
272                errors::InvalidArgument("Sliding window ksize field must "
273                                        "specify 4 dimensions"));
274    OP_REQUIRES(context, stride.size() == 4,
275                errors::InvalidArgument("Sliding window strides field must "
276                                        "specify 4 dimensions"));
277    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
278                errors::Unimplemented(
279                    "Pooling is not yet supported on the batch dimension."));
280    OP_REQUIRES(
281        context, ksize[3] == 1 && stride[3] == 1,
282        errors::Unimplemented(
283            "MaxPoolingGrad is not yet supported on the depth dimension."));
284
285    PoolParameters params{context,  ksize,       stride,
286                          padding_, FORMAT_NHWC, tensor_in.shape()};
287    if (!context->status().ok()) {
288      return;
289    }
290
291    Tensor* output = nullptr;
292    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
293                                {0}, 0, output_shape, &output));
294
295    SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
296        context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
297        out_backprop, params, padding_);
298  }
299
300 private:
301  std::vector<int32> ksize_;
302  std::vector<int32> stride_;
303  Padding padding_;
304  TensorFormat data_format_;
305};
306
307#if GOOGLE_CUDA
308
309template <typename T>
310static void MaxPoolingBackwardCustomKernel(
311    OpKernelContext* context, const std::vector<int32>& size,
312    const std::vector<int32>& stride, Padding padding, const Tensor* tensor_in,
313    const Tensor& out_backprop, const TensorShape& tensor_in_shape) {
314  Tensor* output = nullptr;
315  OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
316                              {0}, 0, tensor_in_shape, &output));
317
318  PoolParameters params{context, size,        stride,
319                        padding, FORMAT_NHWC, tensor_in_shape};
320  if (!context->status().ok()) {
321    return;
322  }
323
324  functor::MaxPoolBackwardNoMask<T>()(
325      tensor_in->flat<T>().data(), params.tensor_in_batch,
326      params.tensor_in_rows, params.tensor_in_cols, params.depth,
327      params.out_height, params.out_width, params.window_rows,
328      params.window_cols, params.row_stride, params.col_stride, params.pad_rows,
329      params.pad_cols, out_backprop.flat<T>().data(), output->flat<T>().data(),
330      context->eigen_device<Eigen::GpuDevice>());
331}
332
333template <class T>
334class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
335 public:
336  typedef Eigen::GpuDevice Device;
337
338  explicit MaxPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
339    string data_format;
340    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
341    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
342                errors::InvalidArgument("Invalid data format"));
343    if (context->num_inputs() == 3) {
344      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
345      OP_REQUIRES(context, ksize_.size() == 4,
346                  errors::InvalidArgument("Sliding window ksize field must "
347                                          "specify 4 dimensions"));
348      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
349      OP_REQUIRES(context, stride_.size() == 4,
350                  errors::InvalidArgument("Sliding window strides field must "
351                                          "specify 4 dimensions"));
352      const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
353      const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
354      OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
355                  errors::Unimplemented(
356                      "Pooling is not yet supported on the batch dimension."));
357    }
358    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
359
360    use_dnn_ = CanUseCudnn();
361  }
362
363  void Compute(OpKernelContext* context) override {
364    const Tensor& tensor_in = context->input(0);
365    const Tensor& tensor_out = context->input(1);
366    const Tensor& out_backprop = context->input(2);
367
368    // For maxpooling, tensor_in should have 4 dimensions.
369    OP_REQUIRES(context, tensor_in.dims() == 4,
370                errors::InvalidArgument("tensor_in must be 4-dimensional"));
371    OP_REQUIRES(context, tensor_out.dims() == 4,
372                errors::InvalidArgument("tensor_out must be 4-dimensional"));
373    // For maxpooling, out_backprop should have 4 dimensions.
374    OP_REQUIRES(context, out_backprop.dims() == 4,
375                errors::InvalidArgument("out_backprop must be 4-dimensional"));
376
377    TensorShape output_shape = tensor_in.shape();
378
379    std::vector<int32> ksize = ksize_;
380    std::vector<int32> stride = stride_;
381    if (context->num_inputs() == 5) {
382      const Tensor& tensor_ksize = context->input(3);
383      auto value_ksize = tensor_ksize.flat<int32>();
384      ksize.resize(tensor_ksize.shape().num_elements());
385      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
386
387      const Tensor& tensor_stride = context->input(4);
388      auto value_stride = tensor_stride.flat<int32>();
389      stride.resize(tensor_stride.shape().num_elements());
390      std::copy_n(&value_stride(0), stride.size(), stride.begin());
391    }
392    OP_REQUIRES(context, ksize.size() == 4,
393                errors::InvalidArgument("Sliding window ksize field must "
394                                        "specify 4 dimensions"));
395    OP_REQUIRES(context, stride.size() == 4,
396                errors::InvalidArgument("Sliding window strides field must "
397                                        "specify 4 dimensions"));
398    const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N');
399    const int32 stride_n = GetTensorDim(stride, data_format_, 'N');
400    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
401                errors::Unimplemented(
402                    "Pooling is not yet supported on the batch dimension."));
403
404    if (use_dnn_) {
405      DnnPoolingGradOp<T>::Compute(
406          context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize,
407          stride, padding_, data_format_, &tensor_in, &tensor_out, out_backprop,
408          output_shape);
409    } else {
410      CHECK(data_format_ == FORMAT_NHWC)
411          << "Non-Cudnn MaxPoolGrad only supports NHWC format";
412      MaxPoolingBackwardCustomKernel<T>(context, ksize, stride, padding_,
413                                        &tensor_in, out_backprop, output_shape);
414    }
415  }
416
417 private:
418  std::vector<int32> ksize_;
419  std::vector<int32> stride_;
420  Padding padding_;
421  TensorFormat data_format_;
422  bool use_dnn_;
423};
424
425#endif  // GOOGLE_CUDA
426
427// The operation to compute gradient of MaxPool gradients.
428// It takes three inputs:
429//   - The original input tensor
430//   - The original output tensor
431//   - Backprop tensor for output gradients
432// It produces one output: backprop tensor for output gradient.
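//
// Conceptually, for each pooling window the emitted value is the incoming
// gradient at the input position that attained the maximum; ties are resolved
// to the first matching position (see SpatialMaxPoolGradGrad below).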
433template <class Device, class T>
434class MaxPoolingGradGradOp : public OpKernel {
435 public:
436  explicit MaxPoolingGradGradOp(OpKernelConstruction* context)
437      : OpKernel(context) {
438    string data_format;
439    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
440    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
441                errors::InvalidArgument("Invalid data format"));
442    OP_REQUIRES(
443        context, data_format_ == FORMAT_NHWC,
444        errors::InvalidArgument(
445            "Default MaxPoolingGradGradOp only supports NHWC ",
446            "on device type ", DeviceTypeString(context->device_type())));
447
448    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
449
450    if (context->num_inputs() == 3) {
451      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
452      OP_REQUIRES(context, ksize_.size() == 4,
453                  errors::InvalidArgument("Sliding window ksize field must "
454                                          "specify 4 dimensions"));
455      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
456      OP_REQUIRES(context, stride_.size() == 4,
457                  errors::InvalidArgument("Sliding window strides field must "
458                                          "specify 4 dimensions"));
459      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
460                  errors::Unimplemented(
461                      "Pooling is not yet supported on the batch dimension."));
462      OP_REQUIRES(context, ksize_[3] == 1 && stride_[3] == 1,
463                  errors::Unimplemented("MaxPoolingGradGrad is not yet "
464                                        "supported on the depth dimension."));
465    }
466  }
467
468  void Compute(OpKernelContext* context) override {
469    const Tensor& tensor_in = context->input(0);
470    const Tensor& tensor_out = context->input(1);
471    const Tensor& out_grad_backprop = context->input(2);
472
473    // For maxpooling, tensor_in should have 4 dimensions.
474    OP_REQUIRES(context, tensor_in.dims() == 4,
475                errors::InvalidArgument("tensor_in must be 4-dimensional"));
476    OP_REQUIRES(context, tensor_out.dims() == 4,
477                errors::InvalidArgument("tensor_out must be 4-dimensional"));
478    // For maxpooling, out_grad_backprop should have 4 dimensions.
479    OP_REQUIRES(
480        context, out_grad_backprop.dims() == 4,
481        errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));
482
483    std::vector<int32> ksize = ksize_;
484    std::vector<int32> stride = stride_;
485    if (context->num_inputs() == 5) {
486      const Tensor& tensor_ksize = context->input(3);
487      auto value_ksize = tensor_ksize.flat<int32>();
488      ksize.resize(tensor_ksize.shape().num_elements());
489      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
490
491      const Tensor& tensor_stride = context->input(4);
492      auto value_stride = tensor_stride.flat<int32>();
493      stride.resize(tensor_stride.shape().num_elements());
494      std::copy_n(&value_stride(0), stride.size(), stride.begin());
495    }
496
497    OP_REQUIRES(context, ksize.size() == 4,
498                errors::InvalidArgument("Sliding window ksize field must "
499                                        "specify 4 dimensions"));
500    OP_REQUIRES(context, stride.size() == 4,
501                errors::InvalidArgument("Sliding window strides field must "
502                                        "specify 4 dimensions"));
503    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
504                errors::Unimplemented(
505                    "Pooling is not yet supported on the batch dimension."));
506    OP_REQUIRES(
507        context, ksize[3] == 1 && stride[3] == 1,
508        errors::Unimplemented("MaxPoolingGradGrad is not yet supported "
509                              "on the depth dimension."));
510
511    PoolParameters params{context,  ksize,       stride,
512                          padding_, FORMAT_NHWC, tensor_in.shape()};
513    Tensor* output = nullptr;
514    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
515                                {2}, 0, tensor_out.shape(), &output));
516
517    SpatialMaxPoolGradGrad(context, output, tensor_in, tensor_out,
518                           out_grad_backprop, params, padding_);
519  }
520
521 private:
522  void SpatialMaxPoolGradGrad(OpKernelContext* context, Tensor* bottom_diff,
523                              const Tensor& tensor_in, const Tensor& tensor_out,
524                              const Tensor& top_diff,
525                              const PoolParameters& params,
526                              const Padding& padding) {
527    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
528        ConstEigenMatrixMap;
529    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
530        EigenMatrixMap;
531
532    ConstEigenMatrixMap in_mat(
533        tensor_in.flat<T>().data(), params.depth,
534        params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
535    ConstEigenMatrixMap out_mat(
536        tensor_out.flat<T>().data(), params.depth,
537        params.out_width * params.out_height * params.tensor_in_batch);
538    ConstEigenMatrixMap top_diff_mat(
539        top_diff.flat<T>().data(), params.depth,
540        params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
541    EigenMatrixMap bottom_diff_mat(
542        bottom_diff->flat<T>().data(), params.depth,
543        params.out_width * params.out_height * params.tensor_in_batch);
544
545    const DeviceBase::CpuWorkerThreads& worker_threads =
546        *(context->device()->tensorflow_cpu_worker_threads());
547
548    // The following code does the following:
549    // 1. Flattens the input, output, top_diff and bottom_diff tensors into
550    //    two dimensional arrays.
551    //    tensor_in_as_matrix:
552    //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
553    //    tensor_out_as_matrix:
554    //      depth by (out_width * out_height * tensor_in_batch)
555    //    top_diff_as_matrix:
556    //      depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
557    //    bottom_diff_as_matrix:
558    //      depth by (out_width * out_height * tensor_in_batch)
559    //
560    // 2. Walks through the set of columns in the flattened
561    //    tensor_in_as_matrix, tensor_out_as_matrix, top_diff_as_matrix
562    //    and updates the column(s) corresponding to the maximum values in
563    //    tensor_out_as_matrix with the corresponding values in
564    //    top_diff_as_matrix.
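    //
    // Roughly, for each output element the loop below scans its pooling
    // window, finds the first input equal to the pooled maximum, and copies
    // the corresponding top_diff value into bottom_diff.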
565    auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
566        int64 start, int64 limit) {
567      const int32 depth = params.depth;
568      const int32 in_rows = params.tensor_in_rows;
569      const int32 in_cols = params.tensor_in_cols;
570      const int32 pad_rows = params.pad_rows;
571      const int32 pad_cols = params.pad_cols;
572      const int32 window_rows = params.window_rows;
573      const int32 window_cols = params.window_cols;
574      const int32 row_stride = params.row_stride;
575      const int32 col_stride = params.col_stride;
576      const int32 out_height = params.out_height;
577      const int32 out_width = params.out_width;
578
579      {
580        // Initializes the output grad backprop tensor with 0.
581        const int32 output_image_size = out_height * out_width * params.depth;
582        EigenMatrixMap bottom_diff_shard(
583            bottom_diff_mat.data() + start * output_image_size, 1,
584            (limit - start) * output_image_size);
585        bottom_diff_shard.setZero();
586      }
587
588      for (int b = start; b < limit; ++b) {
589        for (int ph = 0; ph < out_height; ++ph) {
590          for (int pw = 0; pw < out_width; ++pw) {
591            // (h_start, h_end) * (w_start, w_end) is the range of input
592            // locations covered by this output element.
593            int h_start = ph * row_stride - pad_rows;
594            const int h_end = std::min(h_start + window_rows, in_rows);
595            int w_start = pw * col_stride - pad_cols;
596            const int w_end = std::min(w_start + window_cols, in_cols);
597            h_start = std::max(h_start, 0);
598            w_start = std::max(w_start, 0);
599            const int out_index = (b * out_height + ph) * out_width + pw;
600            // Find value corresponding to the input maximum in top_diff.
601            for (int d = 0; d < depth; ++d) {
602              const T& output_ref = out_mat.coeffRef(d, out_index);
603              bool should_stop = false;
604              for (int h = h_start; h < h_end && !should_stop; ++h) {
605                for (int w = w_start; w < w_end && !should_stop; ++w) {
606                  const int in_index = (b * in_rows + h) * in_cols + w;
607                  const T& input_ref = in_mat.coeffRef(d, in_index);
608                  if (output_ref == input_ref) {
609                    T& bottom_diff_ref = bottom_diff_mat.coeffRef(d, out_index);
610                    bottom_diff_ref = top_diff_mat.coeffRef(d, in_index);
611                    should_stop = true;
612                  }
613                }
614              }
615            }
616          }
617        }
618      }
619    };
620
621    const int64 shard_cost = params.out_width * params.out_height *
622                             params.depth * params.window_rows *
623                             params.window_cols;
624    Shard(worker_threads.num_threads, worker_threads.workers,
625          params.tensor_in_batch, shard_cost, shard);
626  }
627
628  std::vector<int32> ksize_;
629  std::vector<int32> stride_;
630  Padding padding_;
631  TensorFormat data_format_;
632};
633
634#if GOOGLE_CUDA
635
636template <class T>
637class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
638 public:
639  typedef Eigen::GpuDevice Device;
640
641  explicit MaxPoolingGradGradOp(OpKernelConstruction* context)
642      : OpKernel(context) {
643    string data_format;
644    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
645    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
646                errors::InvalidArgument("Invalid data format"));
647    if (context->num_inputs() == 3) {
648      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
649      OP_REQUIRES(context, ksize_.size() == 4,
650                  errors::InvalidArgument("Sliding window ksize field must "
651                                          "specify 4 dimensions"));
652      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
653      OP_REQUIRES(context, stride_.size() == 4,
654                  errors::InvalidArgument("Sliding window strides field must "
655                                          "specify 4 dimensions"));
656      const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
657      const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
658      OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
659                  errors::Unimplemented(
660                      "Pooling is not yet supported on the batch dimension."));
661    }
662    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
663  }
664
665  void Compute(OpKernelContext* context) override {
666    const Tensor& tensor_in = context->input(0);
667    const Tensor& tensor_out = context->input(1);
668    const Tensor& out_grad_backprop = context->input(2);
669
670    // For maxpooling, tensor_in should have 4 dimensions.
671    OP_REQUIRES(context, tensor_in.dims() == 4,
672                errors::InvalidArgument("tensor_in must be 4-dimensional"));
673    OP_REQUIRES(context, tensor_out.dims() == 4,
674                errors::InvalidArgument("tensor_out must be 4-dimensional"));
675    // For maxpooling, out_grad_backprop should have 4 dimensions.
676    OP_REQUIRES(
677        context, out_grad_backprop.dims() == 4,
678        errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));
679
680    Tensor* output = nullptr;
681    OP_REQUIRES_OK(context,
682                   context->allocate_output(0, tensor_out.shape(), &output));
683
684    std::vector<int32> ksize = ksize_;
685    std::vector<int32> stride = stride_;
686    if (context->num_inputs() == 5) {
687      const Tensor& tensor_ksize = context->input(3);
688      auto value_ksize = tensor_ksize.flat<int32>();
689      ksize.resize(tensor_ksize.shape().num_elements());
690      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
691
692      const Tensor& tensor_stride = context->input(4);
693      auto value_stride = tensor_stride.flat<int32>();
694      stride.resize(tensor_stride.shape().num_elements());
695      std::copy_n(&value_stride(0), stride.size(), stride.begin());
696    }
697
698    OP_REQUIRES(context, ksize.size() == 4,
699                errors::InvalidArgument("Sliding window ksize field must "
700                                        "specify 4 dimensions"));
701    OP_REQUIRES(context, stride.size() == 4,
702                errors::InvalidArgument("Sliding window strides field must "
703                                        "specify 4 dimensions"));
704    const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N');
705    const int32 stride_n = GetTensorDim(stride, data_format_, 'N');
706    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
707                errors::Unimplemented(
708                    "Pooling is not yet supported on the batch dimension."));
709
710    PoolParameters params{context,  ksize,        stride,
711                          padding_, data_format_, tensor_in.shape()};
712
713    functor::MaxPoolGradBackwardNoMask<T>()(
714        data_format_, tensor_in.flat<T>().data(), tensor_out.flat<T>().data(),
715        params.tensor_in_batch, params.out_height, params.out_width,
716        params.depth, params.tensor_in_rows, params.tensor_in_cols,
717        params.window_rows, params.window_cols, params.row_stride,
718        params.col_stride, params.pad_rows, params.pad_cols,
719        out_grad_backprop.flat<T>().data(), output->flat<T>().data(),
720        context->eigen_device<Eigen::GpuDevice>());
721  }
722
723 private:
724  std::vector<int32> ksize_;
725  std::vector<int32> stride_;
726  Padding padding_;
727  TensorFormat data_format_;
728  bool use_dnn_;
729};
730
731#endif  // GOOGLE_CUDA
732
733template <typename Device, typename T>
734struct LaunchMaxPoolingNoMask;
735
736template <typename Device, typename T>
737class MaxPoolingNoMaskOp : public OpKernel {
738 public:
739  explicit MaxPoolingNoMaskOp(OpKernelConstruction* context)
740      : OpKernel(context) {
741    string data_format;
742    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
743    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
744                errors::InvalidArgument("Invalid data format"));
745    OP_REQUIRES(
746        context, data_format_ == FORMAT_NHWC,
747        errors::InvalidArgument(
748            "Default MaxPoolingNoMaskOp only supports NHWC on device type ",
749            DeviceTypeString(context->device_type())));
750    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
751    OP_REQUIRES(context, ksize_.size() == 4,
752                errors::InvalidArgument("Sliding window ksize field must "
753                                        "specify 4 dimensions"));
754    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
755    OP_REQUIRES(context, stride_.size() == 4,
756                errors::InvalidArgument("Sliding window stride field must "
757                                        "specify 4 dimensions"));
758    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
759    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
760                errors::Unimplemented(
761                    "Pooling is not yet supported on the batch dimension."));
762  }
763
764  void Compute(OpKernelContext* context) override {
765    const Tensor& tensor_in = context->input(0);
766
767    PoolParameters params{context,  ksize_,       stride_,
768                          padding_, data_format_, tensor_in.shape()};
769    if (!context->status().ok()) {
770      return;
771    }
772
773    TensorShape out_shape({params.tensor_in_batch, params.out_height,
774                           params.out_width, params.depth});
775    Tensor* output = nullptr;
776    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
777
778    LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
779                                              output);
780  }
781
782 private:
783  std::vector<int32> ksize_;
784  std::vector<int32> stride_;
785  Padding padding_;
786  TensorFormat data_format_;
787};
788
789template <typename Device, typename T>
790class MaxPoolingNoMaskV2Op : public OpKernel {
791 public:
792  explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context)
793      : OpKernel(context) {
794    string data_format;
795    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
796    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
797                errors::InvalidArgument("Invalid data format"));
798    OP_REQUIRES(
799        context, data_format_ == FORMAT_NHWC,
800            "Default MaxPoolingNoMaskV2Op only supports NHWC on device type ",
801            "Default MaxPoolingNoMaskOp only supports NHWC on device type ",
802            DeviceTypeString(context->device_type())));
803    if (context->num_inputs() == 1) {
804      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
805      OP_REQUIRES(context, ksize_.size() == 4,
806                  errors::InvalidArgument("Sliding window ksize field must "
807                                          "specify 4 dimensions"));
808      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
809      OP_REQUIRES(context, stride_.size() == 4,
810                  errors::InvalidArgument("Sliding window stride field must "
811                                          "specify 4 dimensions"));
812      OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
813                  errors::Unimplemented(
814                      "Pooling is not yet supported on the batch dimension."));
815    }
816    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
817  }
818
819  void Compute(OpKernelContext* context) override {
820    const Tensor& tensor_in = context->input(0);
821
822    std::vector<int32> ksize = ksize_;
823    std::vector<int32> stride = stride_;
824
825    if (context->num_inputs() != 1) {
826      const Tensor& tensor_ksize = context->input(1);
827      auto value_ksize = tensor_ksize.flat<int32>();
828      ksize.resize(tensor_ksize.shape().num_elements());
829      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
830
831      const Tensor& tensor_stride = context->input(2);
832      auto value_stride = tensor_stride.flat<int32>();
833      stride.resize(tensor_stride.shape().num_elements());
834      std::copy_n(&value_stride(0), stride.size(), stride.begin());
835    }
836    OP_REQUIRES(context, ksize.size() == 4,
837                errors::InvalidArgument("Sliding window ksize field must "
838                                        "specify 4 dimensions"));
839    OP_REQUIRES(context, stride.size() == 4,
840                errors::InvalidArgument("Sliding window stride field must "
841                                        "specify 4 dimensions"));
842    OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
843                errors::Unimplemented(
844                    "Pooling is not yet supported on the batch dimension."));
845    PoolParameters params{context,  ksize,        stride,
846                          padding_, data_format_, tensor_in.shape()};
847    if (!context->status().ok()) {
848      return;
849    }
850
851    TensorShape out_shape({params.tensor_in_batch, params.out_height,
852                           params.out_width, params.depth});
853    Tensor* output = nullptr;
854    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
855
856    LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
857                                              output);
858  }
859
860 private:
861  std::vector<int32> ksize_;
862  std::vector<int32> stride_;
863  Padding padding_;
864  TensorFormat data_format_;
865};
866
867template <typename Device, typename T>
868struct LaunchMaxPoolingWithArgmax;
869
870template <typename Device, typename T>
871class MaxPoolingWithArgmaxOp : public OpKernel {
872 public:
873  explicit MaxPoolingWithArgmaxOp(OpKernelConstruction* context)
874      : OpKernel(context) {
875    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
876    OP_REQUIRES(context, ksize_.size() == 4,
877                errors::InvalidArgument("Sliding window ksize field must "
878                                        "specify 4 dimensions"));
879    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
880    OP_REQUIRES(context, stride_.size() == 4,
881                errors::InvalidArgument("Sliding window stride field must "
882                                        "specify 4 dimensions"));
883    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
884    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
885                errors::Unimplemented(
886                    "Pooling is not yet supported on the batch dimension."));
887  }
888
889  void Compute(OpKernelContext* context) override {
890    const Tensor& tensor_in = context->input(0);
891
892    PoolParameters params{context,  ksize_,      stride_,
893                          padding_, FORMAT_NHWC, tensor_in.shape()};
894    if (!context->status().ok()) {
895      return;
896    }
897
898    TensorShape out_shape({params.tensor_in_batch, params.out_height,
899                           params.out_width, params.depth});
900    Tensor* output = nullptr;
901    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
902    Tensor* argmax = nullptr;
903    OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax));
904
905    LaunchMaxPoolingWithArgmax<Device, T>::launch(context, params, tensor_in,
906                                                  output, argmax);
907  }
908
909 private:
910  std::vector<int32> ksize_;
911  std::vector<int32> stride_;
912  Padding padding_;
913};
914
915template <typename Device, typename T>
916struct LaunchMaxPoolingGradWithArgmax;
917
918template <typename Device, typename T>
919class MaxPoolingGradWithArgmaxOp : public OpKernel {
920 public:
921  explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context)
922      : OpKernel(context) {
923    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
924    OP_REQUIRES(context, ksize_.size() == 4,
925                errors::InvalidArgument("Sliding window ksize field must "
926                                        "specify 4 dimensions"));
927    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
928    OP_REQUIRES(context, stride_.size() == 4,
929                errors::InvalidArgument("Sliding window stride field must "
930                                        "specify 4 dimensions"));
931    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
932    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
933                errors::Unimplemented(
934                    "Pooling is not yet supported on the batch dimension."));
935  }
936
937  void Compute(OpKernelContext* context) override {
938    const Tensor& tensor_in = context->input(0);
939    const Tensor& grad_in = context->input(1);
940    const Tensor& argmax = context->input(2);
941
942    PoolParameters params{context,  ksize_,      stride_,
943                          padding_, FORMAT_NHWC, tensor_in.shape()};
944    if (!context->status().ok()) {
945      return;
946    }
947
948    TensorShape out_shape({params.tensor_in_batch, params.tensor_in_rows,
949                           params.tensor_in_cols, params.depth});
950    Tensor* grad_out = nullptr;
951    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
952                                {1}, 0, out_shape, &grad_out));
953
954    LaunchMaxPoolingGradWithArgmax<Device, T>::launch(context, params, grad_in,
955                                                      argmax, grad_out);
956  }
957
958 private:
959  std::vector<int32> ksize_;
960  std::vector<int32> stride_;
961  Padding padding_;
962};
963
964template <typename Device, typename T>
965struct LaunchMaxPoolingGradGradWithArgmax;
966
967template <typename Device, typename T>
968class MaxPoolingGradGradWithArgmaxOp : public OpKernel {
969 public:
970  explicit MaxPoolingGradGradWithArgmaxOp(OpKernelConstruction* context)
971      : OpKernel(context) {
972    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
973    OP_REQUIRES(context, ksize_.size() == 4,
974                errors::InvalidArgument("Sliding window ksize field must "
975                                        "specify 4 dimensions"));
976    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
977    OP_REQUIRES(context, stride_.size() == 4,
978                errors::InvalidArgument("Sliding window stride field must "
979                                        "specify 4 dimensions"));
980    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
981    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
982                errors::Unimplemented(
983                    "Pooling is not yet supported on the batch dimension."));
984  }
985
986  void Compute(OpKernelContext* context) override {
987    const Tensor& tensor_in = context->input(0);
988    const Tensor& grad_in = context->input(1);
989    const Tensor& argmax = context->input(2);
990
991    PoolParameters params{context,  ksize_,      stride_,
992                          padding_, FORMAT_NHWC, tensor_in.shape()};
993    if (!context->status().ok()) {
994      return;
995    }
996
997    TensorShape out_shape({params.tensor_in_batch, params.out_height,
998                           params.out_width, params.depth});
999
1000    Tensor* grad_out = nullptr;
1001    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1002                                {1}, 0, out_shape, &grad_out));
1003
1004    LaunchMaxPoolingGradGradWithArgmax<Device, T>::launch(
1005        context, params, grad_in, argmax, grad_out);
1006  }
1007
1008 private:
1009  std::vector<int32> ksize_;
1010  std::vector<int32> stride_;
1011  Padding padding_;
1012};
1013
1014#if GOOGLE_CUDA
1015template <typename T>
1016class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
1017 public:
1018  typedef GPUDevice Device;
1019  explicit MaxPoolingNoMaskOp(OpKernelConstruction* context)
1020      : OpKernel(context) {
1021    string data_format;
1022    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1023    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1024                errors::InvalidArgument("Invalid data format"));
1025    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
1026    OP_REQUIRES(context, ksize_.size() == 4,
1027                errors::InvalidArgument("Sliding window ksize field must "
1028                                        "specify 4 dimensions"));
1029    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1030    OP_REQUIRES(context, stride_.size() == 4,
1031                errors::InvalidArgument("Sliding window stride field must "
1032                                        "specify 4 dimensions"));
1033    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1034    const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
1035    const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
1036    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
1037                errors::Unimplemented(
1038                    "Pooling is not yet supported on the batch dimension."));
1039    use_dnn_ = CanUseCudnn();
1040  }
1041
1042  void Compute(OpKernelContext* context) override {
1043    const Tensor& tensor_in = context->input(0);
1044
1045    PoolParameters params{context,  ksize_,       stride_,
1046                          padding_, data_format_, tensor_in.shape()};
1047    if (!context->status().ok()) {
1048      return;
1049    }
1050
1051    TensorShape out_shape =
1052        ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
1053                        params.out_width, params.depth);
1054    if (use_dnn_ && data_format_ == FORMAT_NCHW) {
1055      DnnPoolingOp<T>::Compute(
1056          context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_,
1057          stride_, padding_, data_format_, tensor_in, out_shape);
1058    } else {
1059      CHECK(data_format_ == FORMAT_NHWC)
1060          << "Non-Cudnn MaxPool only supports NHWC format";
1061      Tensor* output = nullptr;
1062      OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
1063      LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
1064                                                output);
1065    }
1066  }
1067
1068 private:
1069  std::vector<int32> ksize_;
1070  std::vector<int32> stride_;
1071  Padding padding_;
1072  TensorFormat data_format_;
1073  bool use_dnn_;
1074};
1075
1076template <typename T>
1077class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
1078 public:
1079  typedef GPUDevice Device;
1080  explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context)
1081      : OpKernel(context) {
1082    string data_format;
1083    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1084    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1085                errors::InvalidArgument("Invalid data format"));
1086    if (context->num_inputs() == 1) {
1087      OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
1088      OP_REQUIRES(context, ksize_.size() == 4,
1089                  errors::InvalidArgument("Sliding window ksize field must "
1090                                          "specify 4 dimensions"));
1091      OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1092      OP_REQUIRES(context, stride_.size() == 4,
1093                  errors::InvalidArgument("Sliding window stride field must "
1094                                          "specify 4 dimensions"));
1095      const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
1096      const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
1097      OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
1098                  errors::Unimplemented(
1099                      "Pooling is not yet supported on the batch dimension."));
1100    }
1101    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1102    use_dnn_ = CanUseCudnn();
1103  }
1104
1105  void Compute(OpKernelContext* context) override {
1106    const Tensor& tensor_in = context->input(0);
1107
1108    std::vector<int32> ksize = ksize_;
1109    std::vector<int32> stride = stride_;
1110
1111    if (context->num_inputs() != 1) {
1112      const Tensor& tensor_ksize = context->input(1);
1113      auto value_ksize = tensor_ksize.flat<int32>();
1114      ksize.resize(tensor_ksize.shape().num_elements());
1115      std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
1116
1117      const Tensor& tensor_stride = context->input(2);
1118      auto value_stride = tensor_stride.flat<int32>();
1119      stride.resize(tensor_stride.shape().num_elements());
1120      std::copy_n(&value_stride(0), stride.size(), stride.begin());
1121    }
1122    OP_REQUIRES(context, ksize.size() == 4,
1123                errors::InvalidArgument("Sliding window ksize field must "
1124                                        "specify 4 dimensions"));
1125    OP_REQUIRES(context, stride.size() == 4,
1126                errors::InvalidArgument("Sliding window stride field must "
1127                                        "specify 4 dimensions"));
1128    const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N');
1129    const int32 stride_n = GetTensorDim(stride, data_format_, 'N');
1130    OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
1131                errors::Unimplemented(
1132                    "Pooling is not yet supported on the batch dimension."));
1133
1134    PoolParameters params{context,  ksize,        stride,
1135                          padding_, data_format_, tensor_in.shape()};
1136    if (!context->status().ok()) {
1137      return;
1138    }
1139
1140    TensorShape out_shape =
1141        ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
1142                        params.out_width, params.depth);
1143    if (use_dnn_ && data_format_ == FORMAT_NCHW) {
1144      DnnPoolingOp<T>::Compute(
1145          context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize,
1146          stride, padding_, data_format_, tensor_in, out_shape);
1147    } else {
1148      CHECK(data_format_ == FORMAT_NHWC)
1149          << "Non-Cudnn MaxPool only supports NHWC format";
1150      Tensor* output = nullptr;
1151      OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
1152      LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
1153                                                output);
1154    }
1155  }
1156
1157 private:
1158  std::vector<int32> ksize_;
1159  std::vector<int32> stride_;
1160  Padding padding_;
1161  TensorFormat data_format_;
1162  bool use_dnn_;
1163};
1164
1165template <typename T>
1166struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
1167  static void launch(OpKernelContext* context, const PoolParameters& params,
1168                     const Tensor& input, Tensor* output) {
1169    bool status = functor::MaxPoolForwardWithOptionalArgmax<T>()(
1170        input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
1171        params.tensor_in_cols, params.depth, params.out_height,
1172        params.out_width, params.window_rows, params.window_cols,
1173        params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
1174        output->flat<T>().data(), nullptr, context->eigen_gpu_device());
1175    if (!status) {
1176      context->SetStatus(
1177          errors::Internal("Failed launching MaxPoolForwardNoMask"));
1178    }
1179  }
1180};
1181
1182template <typename T>
1183struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
1184  static void launch(OpKernelContext* context, const PoolParameters& params,
1185                     const Tensor& input, Tensor* output, Tensor* argmax) {
1186    bool status = functor::MaxPoolForwardWithOptionalArgmax<T>()(
1187        input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
1188        params.tensor_in_cols, params.depth, params.out_height,
1189        params.out_width, params.window_rows, params.window_cols,
1190        params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
1191        output->flat<T>().data(),
1192        reinterpret_cast<int64*>(argmax->flat<int64>().data()),
1193        context->eigen_gpu_device());
1194    if (!status) {
1195      context->SetStatus(
1196          errors::Internal("Failed launching MaxPoolForwardWithArgmax"));
1197    }
1198  }
1199};
1200
1201template <typename T>
1202struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
1203  static void launch(OpKernelContext* context, const PoolParameters& params,
1204                     const Tensor& grad_in, const Tensor& argmax,
1205                     Tensor* grad_out) {
1206    const int input_size = params.tensor_in_batch * params.tensor_in_rows *
1207                           params.tensor_in_cols * params.depth;
1208    const int output_size = params.tensor_in_batch * params.out_height *
1209                            params.out_width * params.depth;
1210    const int top_offset = params.out_height * params.out_width * params.depth;
1211    const int bottom_offset =
1212        params.tensor_in_rows * params.tensor_in_cols * params.depth;
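    // top_offset and bottom_offset are the per-image element counts of the
    // pooled output and of the op input, respectively (presumably used by the
    // kernel to index each image within the batch).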
1213    bool status = functor::MaxPoolBackwardWithArgmax<T>()(
1214        output_size, input_size, grad_in.flat<T>().data(),
1215        reinterpret_cast<const int64*>(argmax.flat<int64>().data()), top_offset,
1216        bottom_offset, grad_out->flat<T>().data(), context->eigen_gpu_device());
1217    if (!status) {
1218      context->SetStatus(
1219          errors::Internal("Failed launching MaxPoolBackwardWithArgmax"));
1220    }
1221  }
1222};
1223
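// GPU launcher for MaxPoolGradGradWithArgmax (the second-order gradient).
// Here grad_in has the shape of the original pooling input, so top_offset and
// bottom_offset are swapped relative to the first-order launcher above: the
// "top" stride is the per-image input size and the "bottom" stride is the
// per-image output size.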
1224template <typename T>
1225struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
1226  static void launch(OpKernelContext* context, const PoolParameters& params,
1227                     const Tensor& grad_in, const Tensor& argmax,
1228                     Tensor* grad_out) {
1229    const int input_size = params.tensor_in_batch * params.tensor_in_rows *
1230                           params.tensor_in_cols * params.depth;
1231    const int output_size = params.tensor_in_batch * params.out_height *
1232                            params.out_width * params.depth;
1233    const int top_offset =
1234        params.tensor_in_rows * params.tensor_in_cols * params.depth;
1235    const int bottom_offset =
1236        params.out_width * params.out_height * params.depth;
1237    bool status = functor::MaxPoolGradBackwardWithArgmax<T>()(
1238        output_size, input_size, grad_in.flat<T>().data(),
1239        reinterpret_cast<const int64*>(argmax.flat<int64>().data()), top_offset,
1240        bottom_offset, grad_out->flat<T>().data(), context->eigen_gpu_device());
1241    if (!status) {
1242      context->SetStatus(
1243          errors::Internal("Failed launching MaxPoolGradBackwardWithArgmax"));
1244    }
1245  }
1246};
1247
1248#endif  // GOOGLE_CUDA
1249
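// Registers the first- and second-order gradient kernels (MaxPoolGrad,
// MaxPoolGradGrad) and their V2 variants for device D and element type T.
// The V2 ops take "ksize" and "strides" as host-memory tensor inputs instead
// of attributes, but reuse the same kernel classes.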
1250#define REGISTER_MAX_POOL_KERNELS(D, T)                                  \
1251  REGISTER_KERNEL_BUILDER(                                               \
1252      Name("MaxPoolGrad").Device(DEVICE_##D).TypeConstraint<T>("T"),     \
1253      MaxPoolingGradOp<D##Device, T>);                                   \
1254  REGISTER_KERNEL_BUILDER(                                               \
1255      Name("MaxPoolGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
1256      MaxPoolingGradGradOp<D##Device, T>);                               \
1257  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradV2")                          \
1258                              .Device(DEVICE_##D)                        \
1259                              .HostMemory("ksize")                       \
1260                              .HostMemory("strides")                     \
1261                              .TypeConstraint<T>("T"),                   \
1262                          MaxPoolingGradOp<D##Device, T>);               \
1263  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradV2")                      \
1264                              .Device(DEVICE_##D)                        \
1265                              .HostMemory("ksize")                       \
1266                              .HostMemory("strides")                     \
1267                              .TypeConstraint<T>("T"),                   \
1268                          MaxPoolingGradGradOp<D##Device, T>);
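
// For illustration only (not part of the build): REGISTER_MAX_POOL_KERNELS(CPU, float)
// expands to registrations equivalent to
//   REGISTER_KERNEL_BUILDER(
//       Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       MaxPoolingGradOp<CPUDevice, float>);
// plus the corresponding MaxPoolGradGrad, MaxPoolGradV2, and MaxPoolGradGradV2
// registrations.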
1269
1270// The kernels below are implemented only for the CPU device.
1271#define REGISTER_CPU_ONLY_POOL_KERNELS(T)                          \
1272  REGISTER_KERNEL_BUILDER(                                         \
1273      Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
1274      MaxPoolingOp<CPUDevice, T>);                                 \
1275  REGISTER_KERNEL_BUILDER(                                         \
1276      Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1277      MaxPoolingV2Op<CPUDevice, T>);
1278TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
1279#undef REGISTER_CPU_ONLY_POOL_KERNELS
1280
1281#define REGISTER_CPU_MAX_POOL_KERNELS(T) REGISTER_MAX_POOL_KERNELS(CPU, T);
1282TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_MAX_POOL_KERNELS);
1283#undef REGISTER_CPU_MAX_POOL_KERNELS
1284
1285#if GOOGLE_CUDA
1286
1287// Forward declarations for the functor specializations for GPU.
1288namespace functor {
1289#define DECLARE_GPU_SPEC(T)                                            \
1290  template <>                                                          \
1291  void SpatialMaxPooling<Eigen::GpuDevice, T>::operator()(             \
1292      const Eigen::GpuDevice& d, typename TTypes<T, 4>::Tensor output, \
1293      typename TTypes<T, 4>::ConstTensor input, int window_rows,       \
1294      int window_cols, int row_stride, int col_stride,                 \
1295      const Eigen::PaddingType& padding);                              \
1296  extern template struct SpatialMaxPooling<Eigen::GpuDevice, T>;
1297
1298TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
1299#undef DECLARE_GPU_SPEC
1300}  // namespace functor
1301
1302#define REGISTER_GPU_MAX_POOL_KERNELS(T) REGISTER_MAX_POOL_KERNELS(GPU, T)
1303TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
1304#undef REGISTER_GPU_MAX_POOL_KERNELS
1305
1306// The kernels below are currently implemented only for the GPU device.
1307// Note(jiayq): Currently, the Caffe-style custom implementation is faster
1308// than the default Eigen implementation, so we use the custom kernel by
1309// default. However, you can explicitly request the Eigen version by
1310// selecting the "eigen_tensor" kernel label via kernel_label_map.
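// A minimal usage sketch (assuming the Python frontend's private
// Graph._kernel_label_map API, whose exact form may differ between releases):
//   with graph._kernel_label_map({"MaxPool": "eigen_tensor"}):
//     out = tf.nn.max_pool(x, ksize, strides, padding)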
1311#define REGISTER_GPU_ONLY_POOL_KERNELS(T)                            \
1312  REGISTER_KERNEL_BUILDER(Name("MaxPool")                            \
1313                              .Device(DEVICE_GPU)                    \
1314                              .TypeConstraint<T>("T")                \
1315                              .Label("eigen_tensor"),                \
1316                          MaxPoolingOp<GPUDevice, T>);               \
1317  REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")                          \
1318                              .Device(DEVICE_GPU)                    \
1319                              .HostMemory("ksize")                   \
1320                              .HostMemory("strides")                 \
1321                              .TypeConstraint<T>("T")                \
1322                              .Label("eigen_tensor"),                \
1323                          MaxPoolingV2Op<GPUDevice, T>);             \
1324  REGISTER_KERNEL_BUILDER(                                           \
1325      Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"),     \
1326      MaxPoolingNoMaskOp<GPUDevice, T>);                             \
1327  REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")                          \
1328                              .Device(DEVICE_GPU)                    \
1329                              .HostMemory("ksize")                   \
1330                              .HostMemory("strides")                 \
1331                              .TypeConstraint<T>("T"),               \
1332                          MaxPoolingNoMaskV2Op<GPUDevice, T>);       \
1333  REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")                  \
1334                              .Device(DEVICE_GPU)                    \
1335                              .TypeConstraint<int64>("Targmax")      \
1336                              .TypeConstraint<T>("T"),               \
1337                          MaxPoolingWithArgmaxOp<GPUDevice, T>);     \
1338  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")              \
1339                              .Device(DEVICE_GPU)                    \
1340                              .TypeConstraint<T>("T")                \
1341                              .TypeConstraint<int64>("Targmax"),     \
1342                          MaxPoolingGradWithArgmaxOp<GPUDevice, T>); \
1343  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax")          \
1344                              .Device(DEVICE_GPU)                    \
1345                              .TypeConstraint<T>("T")                \
1346                              .TypeConstraint<int64>("Targmax"),     \
1347                          MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
1348TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS);
1349#undef REGISTER_GPU_ONLY_POOL_KERNELS
1350
1351#endif  // GOOGLE_CUDA
1352
1353#undef REGISTER_MAX_POOL_KERNELS
1354
1355}  // namespace tensorflow
1356