1fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
3fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneLicensed under the Apache License, Version 2.0 (the "License");
4fe97705b706c9dcd36586b6158e30758346c6afdVivek Raneyou may not use this file except in compliance with the License.
5fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneYou may obtain a copy of the License at
6fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
7fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    http://www.apache.org/licenses/LICENSE-2.0
8fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
9fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneUnless required by applicable law or agreed to in writing, software
10fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranedistributed under the License is distributed on an "AS IS" BASIS,
11fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneSee the License for the specific language governing permissions and
13fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranelimitations under the License.
14fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane==============================================================================*/
15fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
16fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// See docs in ../ops/nn_ops.cc.
17fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#ifdef INTEL_MKL
18fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#define EIGEN_USE_THREADS
19fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#include "tensorflow/core/framework/op_kernel.h"
20fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
21fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#include "tensorflow/core/lib/core/errors.h"
22c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba#include "tensorflow/core/util/mkl_util.h"
23fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#include "tensorflow/core/util/padding.h"
24fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
25e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh#ifndef INTEL_MKL_ML
2604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina#include <algorithm>
2704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina#include "mkldnn.hpp"
28982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::algorithm;
29982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::engine;
3004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainausing mkldnn::error;
31982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::memory;
3204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainausing mkldnn::padding_kind;
33982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::pooling_backward;
34982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::pooling_forward;
3504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainausing mkldnn::prop_kind;
3604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina#endif
3704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
38fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranenamespace tensorflow {
39fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
40fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranetypedef Eigen::ThreadPoolDevice CPUDevice;
41fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
42e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh// MKL-DNN is now default. MKL-ML must be specified explicitly.
43e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh#ifdef INTEL_MKL_ML
4404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
45fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// An implementation of MaxPooling (forward).
46fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranetemplate <typename Device, typename T>
47fe97705b706c9dcd36586b6158e30758346c6afdVivek Raneclass MklMaxPoolingOp : public OpKernel {
48fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane public:
49fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane  explicit MklMaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
50fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    string data_format;
51fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
52fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
53fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
54fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                errors::InvalidArgument("Invalid data format"));
55fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
56fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES(context, ksize_.size() == 4,
57fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                errors::InvalidArgument("Sliding window ksize field must "
58fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                                        "specify 4 dimensions"));
59fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
60fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES(context, stride_.size() == 4,
61fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                errors::InvalidArgument("Sliding window stride field must "
62fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                                        "specify 4 dimensions"));
63fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
64fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
65fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                errors::Unimplemented("Pooling is not yet supported on the "
66fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                                      "batch dimension."));
67fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
68fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    workspace_enabled_ = false;
69fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // We may not get this attribute for this node if it does not go through
70fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // graph rewrite pass. So we do not check for error while retrieving this
71fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // attribute value.
72fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    context->GetAttr("workspace_enabled", &workspace_enabled_);
73fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane  }
74fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
75fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane  void Compute(OpKernelContext* context) override {
76c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    MklMaxPoolingOpContext mkl_context;
77fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // Get the input tensor
78fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    const Tensor& tensor_in = MklGetInput(context, 0);
79c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    GetMklShape(context, 0, &mkl_context.input_shape);
80c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
81c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba
82c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.params.in_dim = 4;
83c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    MklPoolParameters pool_params;
84c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    if (input_in_mkl_format == false) {
85c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      pool_params.Init(context, ksize_, stride_, padding_, data_format_,
86c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       tensor_in.shape());
87c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      OP_REQUIRES(
88c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          context, (pool_params.depth_window == 1),
89c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          errors::Unimplemented("Depthwise max pooling not supported by MKL"));
90fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
91fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    } else {
92c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      pool_params.Init(context, ksize_, stride_, padding_, data_format_,
93c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       &mkl_context.input_shape);
94fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    }
95fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
96fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // Extract the parameters for the op from the pooling specs
97fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
98c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
99c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba
100c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.MklCreateLayoutsAndPrimitives(context);
1014eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    OP_REQUIRES_OK(context, context->status());
102fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
103fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // Declare output tensor
104fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    TensorShape tensor_out_shape;
1054eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    MklShape mkl_out_shape, mkl_workspace_shape;
106fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    mkl_out_shape.SetMklTensor(true);
107c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
108c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
109c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                              mkl_context.params.out_sizes,
110c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                              mkl_context.params.out_strides);
111c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
112fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
113fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    Tensor* output_tensor = nullptr;
114c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
115c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                mkl_out_shape.GetMklLayout())) /
116c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                            sizeof(T));
117e0813b2c58890decedcddf84af839a1bf3c0bc76agramesh    AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape,
118fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                              mkl_out_shape);
119fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
120c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    Tensor* workspace_tensor;
121c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void* workspace_buf = nullptr;
122fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
1234eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    TensorShape workspace_shape;
1244eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    mkl_workspace_shape.SetMklTensor(false);
1254eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
1264eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane                               mkl_context.lt_workspace)) /
1274eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane                           sizeof(T));
1284eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape,
1294eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane                              mkl_workspace_shape);
1304eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane
1314eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
1324eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane        static_cast<const void*>(workspace_tensor->flat<T>().data()));
133c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.pooling_res[dnnResourceSrc] =
134c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data()));
135c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>(
136fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane        static_cast<const void*>(output_tensor->flat<T>().data()));
137fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
138c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    CHECK_EQ(
139c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res),
140c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        E_SUCCESS);
141fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
142c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.MklCleanup();
143fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane  }
144fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
145fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane private:
146c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  typedef struct {
147c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    MklPoolingOpParams params;
148c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    MklShape input_shape;
149c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void* pooling_res[dnnResourceNumber];
1504eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    dnnPrimitive_t prim_pooling_fwd = nullptr;
1514eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr;
152c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba
153c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
154c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      bool input_in_mkl_format = input_shape.IsMklTensor();
155c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      // Create or use existing DNN user layout
156c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (input_in_mkl_format == false) {
157c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
158c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                     params.in_sizes, params.in_strides),
159c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                 E_SUCCESS);
160c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      } else {
161c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        lt_user_input = (dnnLayout_t)input_shape.GetCurLayout();
162c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
163fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
164c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
165c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      dnnPrimitiveAttributes_t primAttr = nullptr;
166fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
167c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      // Create DNN primitives
168c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      CHECK_EQ(dnnPoolingCreateForward_F32(
169c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                   &prim_pooling_fwd, primAttr, algorithm, lt_user_input,
170c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                   params.kernel_size, params.kernel_stride, params.in_offset,
171c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                   dnnBorderZerosAsymm),
172c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba               E_SUCCESS);
173fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
174c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      // Creates layout for the workspace
175c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_fwd,
176c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                                dnnResourceWorkspace),
177fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane               E_SUCCESS);
178fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    }
179fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
180c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void MklCleanup() {
181c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      bool input_in_mkl_format = input_shape.IsMklTensor();
182c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
183c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (!input_in_mkl_format) {
184c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
185c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
186c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS);
187fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    }
188c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  } MklMaxPoolingOpContext;
189fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
190c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  std::vector<int32> ksize_;
191c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  std::vector<int32> stride_;
192c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  Padding padding_;
193c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  TensorFormat data_format_;
194c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  bool workspace_enabled_;
195fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane};
196fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
197fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// The operation to compute MaxPool gradients.
198fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// It takes three inputs:
199fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane//   - The original input tensor
200fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane//   - The original output tensor
201fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane//   - Backprop tensor for output
202fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// It produces one output: backprop tensor for input.
203fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranetemplate <class Device, class T>
204fe97705b706c9dcd36586b6158e30758346c6afdVivek Raneclass MklMaxPoolingGradOp : public OpKernel {
205fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane public:
206fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane  explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
207fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane      : OpKernel(context) {
208fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    string data_format;
209fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
210fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
211fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
212fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                errors::InvalidArgument("Invalid data format"));
213fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
214fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES(context, ksize_.size() == 4,
215fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                errors::InvalidArgument("Sliding window ksize field must "
216fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                                        "specify 4 dimensions"));
217fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
218fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES(context, stride_.size() == 4,
219fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                errors::InvalidArgument("Sliding window strides field must "
220fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                                        "specify 4 dimensions"));
221fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
222fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
223fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                errors::Unimplemented(
224fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                    "Pooling is not yet supported on the batch dimension."));
225fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    workspace_enabled_ = false;
226fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // We may not get this attribute for this node if it does not go through
227fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // graph rewrite pass. So we do not check for error while retrieving this
228fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // attribute value.
229fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    context->GetAttr("workspace_enabled", &workspace_enabled_);
230fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane  }
231fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
232fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane  void Compute(OpKernelContext* context) override {
233c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    MklMaxPoolingGradOpContext mkl_context;
234fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // Input - The original input tensor
235fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    const Tensor& tensor_in = MklGetInput(context, 0);
236fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
237fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // Output - Backprop tensor for input.
238fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    Tensor* output_tensor = nullptr;
239fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
240c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    GetMklShape(context, 0, &mkl_context.input_shape);
241c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    GetMklShape(context, 2, &mkl_context.output_backprop_shape);
242c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
243fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
244c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    if (input_in_mkl_format == false)
245c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      mkl_context.params.in_dim = tensor_in.dims();
246fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    else
247c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      mkl_context.params.in_dim = mkl_context.input_shape.GetDimension();
248fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
249c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    MklPoolParameters pool_params;
250c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    if (input_in_mkl_format == false) {
251c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      pool_params.Init(context, ksize_, stride_, padding_, data_format_,
252c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       tensor_in.shape());
253c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      OP_REQUIRES(
254c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          context, (pool_params.depth_window == 1),
255c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          errors::Unimplemented("Depthwise max pooling not supported by MKL"));
256fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
257fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    } else {
258c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      pool_params.Init(context, ksize_, stride_, padding_, data_format_,
259c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       &mkl_context.input_shape);
260fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    }
261fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
262fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // Extract the parameters for the op from the pooling specs
263c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
264fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
265c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.MklCreateLayouts(context);
2664eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    OP_REQUIRES_OK(context, context->status());
2674eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane
268c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.MklCreatePrimitives(context, workspace_enabled_);
2694eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    OP_REQUIRES_OK(context, context->status());
2704eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane
271c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.MklPrepareInputs(context, workspace_enabled_);
2724eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    OP_REQUIRES_OK(context, context->status());
273fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
274fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    // Create shape for the input back prop output
275fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    TensorShape mkl_input_backprop;
276c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    MklShape mkl_output_shape;
277c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_output_shape.SetMklTensor(true);
278c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_output_shape.SetMklLayout(mkl_context.prim_pooling_bwd,
279c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                  dnnResourceDiffSrc);
280c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_output_shape.SetTfLayout(mkl_context.params.in_dim,
281c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                 mkl_context.params.in_sizes,
282c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                 mkl_context.params.in_strides);
283c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_output_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
284fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
285fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    mkl_input_backprop.AddDim(
286fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane        dnnLayoutGetMemorySize_F32(
287c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba            static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) /
288fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane        sizeof(T));
289e0813b2c58890decedcddf84af839a1bf3c0bc76agramesh    AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop,
290c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                              mkl_output_shape);
291c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
292c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        static_cast<const void*>(output_tensor->flat<T>().data()));
293fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
294c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    CHECK_EQ(
295c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
296c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        E_SUCCESS);
297fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
298c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    mkl_context.MklCleanup(workspace_enabled_);
299fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane  }
300fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
301fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane private:
302c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  typedef struct {
303c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    MklPoolingOpParams params;
304c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    MklShape input_shape, output_backprop_shape;
305c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void* pooling_resfwd[dnnResourceNumber];
306c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void* pooling_res[dnnResourceNumber];
3074eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr,
3084eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane                   convert_input = nullptr, convert_outbackprop = nullptr;
3094eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr,
3104eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane                lt_input_user = nullptr, lt_input_prim = nullptr;
311c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void* input_buf;
312c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void* outbackprop_buf;
3134eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    Tensor tmp_output_buf_tensor;
3144eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    Tensor workspace_buf_tensor;
3154eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane    Tensor input_buf_tensor, outbackprop_buf_tensor;
316c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba
317c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void MklCreateLayouts(OpKernelContext* context) {
318c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      bool input_in_mkl_format = input_shape.IsMklTensor();
319c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
320c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      // Create DNN user layout for input and outbackprop or get existing layout
321c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (input_in_mkl_format == false) {
322c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnLayoutCreate_F32(&lt_input_user, params.in_dim,
323c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                     params.in_sizes, params.in_strides),
324c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                 E_SUCCESS);
325c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      } else {
326c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        lt_input_user = (dnnLayout_t)input_shape.GetCurLayout();
327c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
328fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
329e16cd2ede6a739a1c9c4576ebf8a4fe81e83f39eTaehoon Lee      // We don't care about the output layout for now as we can create it from
330c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      // primitives for the max pooling fwd prop
331c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (outbackprop_in_mkl_format == false) {
332c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnLayoutCreate_F32(&lt_outbackprop_user, params.in_dim,
333c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                     params.out_sizes, params.out_strides),
334c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                 E_SUCCESS);
335c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      } else {
336c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        lt_outbackprop_user = (dnnLayout_t)output_backprop_shape.GetCurLayout();
337c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
338fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    }
339fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
340c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    // Create DNN primitives
341c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void MklCreatePrimitives(OpKernelContext* context, bool workspace_enabled) {
342c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
343c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      dnnPrimitiveAttributes_t primAttr = nullptr;
344c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba
345c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (workspace_enabled == false) {
346c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnPoolingCreateForward_F32(
347c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                     &prim_pooling_fwd, primAttr, algorithm, lt_input_user,
348c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                     params.kernel_size, params.kernel_stride, params.in_offset,
349c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                     dnnBorderZerosAsymm),
350c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                 E_SUCCESS);
351c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
352fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
353c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      CHECK_EQ(dnnPoolingCreateBackward_F32(
354c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                   &prim_pooling_bwd, primAttr, algorithm, lt_input_user,
355c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                   params.kernel_size, params.kernel_stride, params.in_offset,
356c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                   dnnBorderZerosAsymm),
357fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane               E_SUCCESS);
358fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
359c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      // Creates conversions
360c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
361c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                   &lt_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst),
362fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane               E_SUCCESS);
363fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
364c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (workspace_enabled == false) {
365c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
366c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                     &lt_input_prim, prim_pooling_fwd, dnnResourceSrc),
367fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane                 E_SUCCESS);
368c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        if (!dnnLayoutCompare_F32(lt_input_user, lt_input_prim)) {
369c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input_user,
370c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                           lt_input_prim),
371c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                   E_SUCCESS);
372c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          AllocTmpBuffer(context, &input_buf_tensor, lt_input_prim, &input_buf);
373c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        }
374fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane      }
375fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
376c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (!dnnLayoutCompare_F32(lt_outbackprop_user, lt_outbackprop_prim)) {
377c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(
378c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba            dnnConversionCreate_F32(&convert_outbackprop, lt_outbackprop_user,
379c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                                    lt_outbackprop_prim),
380c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba            E_SUCCESS);
381c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        AllocTmpBuffer(context, &outbackprop_buf_tensor, lt_outbackprop_prim,
382c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       &outbackprop_buf);
383c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
384fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    }
385fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
386c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    // Compare incoming tensor layouts with MKL preferred layouts and convert
387c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    // data to the preferred layout if necessary
388c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void MklPrepareInputs(OpKernelContext* context, bool workspace_enabled) {
389c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      const Tensor& tensor_in = MklGetInput(context, 0);
390c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      const Tensor& out_backprop = MklGetInput(context, 2);
391c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      bool input_in_mkl_format = input_shape.IsMklTensor();
392c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
393c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba
3944eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane      void* tmp_output_buf = nullptr;
3954eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane      void* workspace_buf = nullptr;
396c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba
397c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (workspace_enabled == false) {
398c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        if (convert_input != nullptr) {
399c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          if (input_in_mkl_format == false) {
400982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen            CHECK_EQ(dnnConversionExecute_F32(
401982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         convert_input,
402982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         const_cast<void*>(static_cast<const void*>(
403982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                             tensor_in.flat<T>().data())),
404982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         input_buf),
405982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                     E_SUCCESS);
406c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba            CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS);
407c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba            convert_input = nullptr;
408c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          } else {
409c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba            input_shape.GetConvertedFlatData(
410982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                lt_input_prim,
411982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                const_cast<void*>(
412982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                    static_cast<const void*>(tensor_in.flat<T>().data())),
413c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                input_buf);
414c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          }
415c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          pooling_resfwd[dnnResourceSrc] = input_buf;
416fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane        } else {
417c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          pooling_resfwd[dnnResourceSrc] = const_cast<void*>(
418c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba              static_cast<const void*>(tensor_in.flat<T>().data()));
419fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane        }
420fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
421c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        dnnLayout_t lt_workspace;
422c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
423c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                     &lt_workspace, prim_pooling_fwd, dnnResourceWorkspace),
424c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                 E_SUCCESS);
425c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        AllocTmpBuffer(context, &workspace_buf_tensor, lt_workspace,
426c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       &workspace_buf);
427c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        pooling_resfwd[dnnResourceWorkspace] = workspace_buf;
428fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
429c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        dnnLayoutDelete_F32(lt_workspace);
430fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
431c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        // We create the layout for max pooling fwd prop tmp output here
432c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        AllocTmpBuffer(context, &tmp_output_buf_tensor, lt_outbackprop_prim,
433c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       &tmp_output_buf);
434c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        pooling_resfwd[dnnResourceDst] = tmp_output_buf;
435fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
436c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnExecute_F32(prim_pooling_fwd, pooling_resfwd), E_SUCCESS);
437c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        pooling_res[dnnResourceWorkspace] =
438c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba            pooling_resfwd[dnnResourceWorkspace];
439fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane      } else {
440c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        const Tensor& workspace = MklGetInput(context, 3);
441c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        pooling_res[dnnResourceWorkspace] = const_cast<void*>(
442c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba            static_cast<const void*>(workspace.flat<T>().data()));
443fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane      }
444fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
445c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      // Out backprop conversions if needed
446c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (convert_outbackprop != nullptr) {
447c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        if (outbackprop_in_mkl_format == false) {
448c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          CHECK_EQ(dnnConversionExecute_F32(
449c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       convert_outbackprop,
450c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       const_cast<void*>(static_cast<const void*>(
451c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                           out_backprop.flat<T>().data())),
452c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                       outbackprop_buf),
453c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                   E_SUCCESS);
454c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS);
455c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        } else {
456c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba          output_backprop_shape.GetConvertedFlatData(
457982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen              lt_outbackprop_prim,
458982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen              const_cast<void*>(
459982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  static_cast<const void*>(out_backprop.flat<T>().data())),
460c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba              outbackprop_buf);
461c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        }
462c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        pooling_res[dnnResourceDiffDst] = outbackprop_buf;
463c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      } else {
464c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        pooling_res[dnnResourceDiffDst] = const_cast<void*>(
465c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba            static_cast<const void*>(out_backprop.flat<T>().data()));
466c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
467fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    }
468fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
469c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba    void MklCleanup(bool workspace_enabled) {
470c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      bool input_in_mkl_format = input_shape.IsMklTensor();
471c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
472c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (workspace_enabled == false) {
473c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
474c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
475c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS);
476c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (outbackprop_in_mkl_format == false) {
477c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user), E_SUCCESS);
478c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
479c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim), E_SUCCESS);
480c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (input_in_mkl_format == false) {
481c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnLayoutDelete_F32(lt_input_user), E_SUCCESS);
482c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
483c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      if (workspace_enabled == false) {
484c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba        CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim), E_SUCCESS);
485c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba      }
486fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane    }
487c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  } MklMaxPoolingGradOpContext;
488fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
489c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  std::vector<int32> ksize_;
490c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  std::vector<int32> stride_;
491c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  Padding padding_;
492c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  TensorFormat data_format_;
493fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
494c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba  bool workspace_enabled_;
49504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina};  // MklMaxPoolingGradOp
49604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
497e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh#else
49804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
49904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// An implementation of MaxPooling (forward).
50004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainatemplate <typename Device, typename T>
50104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainaclass MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
50204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina public:
50304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina  explicit MklMaxPoolingOp(OpKernelConstruction* context)
504982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      : MklPoolingForwardOpBase<T>(context) {
50504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina    // In Max Pooling, MKLDNN does not allow passing workspace as NULL.
50604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina    // So we set workspace_enabled_ to true.
50704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina    this->workspace_enabled_ = true;
50804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina  }
50904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
51004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina  void Compute(OpKernelContext* context) override {
51104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina    try {
51204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      auto cpu_engine = engine(engine::cpu, 0);
513982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      const Tensor& input_tensor =
514982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          MklGetInput(context, this->kInputTensorIndexInput);
51504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      MklDnnShape dnn_shape_input;
51604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
51704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      this->SanityCheckInput(context, input_tensor, dnn_shape_input);
51804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      if (!context->status().ok()) return;
51904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
52004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      MklDnnData<T> dnn_data_input(&cpu_engine);
52104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      MklDnnData<T> dnn_data_output(&cpu_engine);
522ae700bb7462ee1bd4fed3c89441e962f64c89afdNiranjan Hasabnis      MklDnnData<uint8> dnn_data_wksp(&cpu_engine);
52304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
52404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      // initialize variables for the pooling op
52504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      MklPoolParameters pool_params;
52604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      // Get the input tensor and initialize the pooling parameters
527982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
528982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                           &dnn_data_input);
52904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      OP_REQUIRES_OK(context, context->status());
53004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
53104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      // Declare output tensor
53204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      Tensor* output_tensor = nullptr;
53304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      memory::dims output_dims_mkl_order;
53404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      this->GetOutputDims(pool_params, &output_dims_mkl_order);
53504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
53604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      // If input is in Mkl layout, then just get the memory format from it
53704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      // directly, instead of using input data_format to MaxPool.
53804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      if (dnn_shape_input.IsMklTensor()) {
539982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        dnn_data_output.SetUsrMem(
540982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen            output_dims_mkl_order,
541982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen            static_cast<memory::format>(
542982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                dnn_data_input.GetUsrMemDesc().data.format));
54304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      } else {
54404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina        dnn_data_output.SetUsrMem(output_dims_mkl_order,
54504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina                                  this->data_format_mkldnn_);
54604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      }
54704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
54804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      // describe the memory layout; let mkl-dnn choose the best for the op
54904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
55004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
551982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      auto pool_desc = pooling_forward::desc(
552982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          prop_kind::forward, algorithm::pooling_max,
553982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
554982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({pool_params.row_stride, pool_params.col_stride}),
555982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({pool_params.window_rows, pool_params.window_cols}),
556982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({static_cast<int>(pool_params.pad_top),
557982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        static_cast<int>(pool_params.pad_left)}),
558982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({static_cast<int>(pool_params.pad_bottom),
559982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        static_cast<int>(pool_params.pad_right)}),
560982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          TFPaddingToMklDnnPadding(this->padding_));
561982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      auto pool_fwd_desc =
562982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          pooling_forward::primitive_desc(pool_desc, cpu_engine);
56304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
56404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      this->AllocateOutputTensor(context, pool_fwd_desc, output_dims_mkl_order,
565982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                 this->data_format_mkldnn_, &output_tensor);
56604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      OP_REQUIRES_OK(context, context->status());
56704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      dnn_data_output.SetUsrMemDataHandle(output_tensor);
56804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
56904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      AllocateWorkspaceTensor(context, pool_fwd_desc, &dnn_data_wksp);
57004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      OP_REQUIRES_OK(context, context->status());
57104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
57204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina      this->PrepareAndExecuteNet(pool_fwd_desc, &dnn_data_input,
573982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                 &dnn_data_output, &dnn_data_wksp);
574982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    } catch (mkldnn::error& e) {
575982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      string error_msg = "Status: " + std::to_string(e.status) +
576982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         ", message: " + string(e.message) + ", in file " +
577982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         string(__FILE__) + ":" + std::to_string(__LINE__);
578982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
579982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                              error_msg));
58004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina    }
58104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina  }  // Compute
58204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
58304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina private:
584982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  const int kOutputTensorIndexWorkspace = 1;
585982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
586982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  void AllocateWorkspaceTensor(
587982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OpKernelContext* context,
588982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      const pooling_forward::primitive_desc& pool_fwd_prim_desc,
589982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      MklDnnData<uint8>* dnn_data_wksp) {
590982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    CHECK_NOTNULL(dnn_data_wksp);
591982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    Tensor* workspace_tensor = nullptr;
592982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    memory::primitive_desc workspace_pd =
593982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        pool_fwd_prim_desc.workspace_primitive_desc();
594982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    size_t workspace_bytes = workspace_pd.get_size();
595982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    MklDnnShape workspace_mkl_shape;
596982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    workspace_mkl_shape.SetMklTensor(false);
597982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    TensorShape workspace_tf_shape;
598982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    workspace_tf_shape.AddDim(workspace_bytes);
599982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace,
600982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                              &workspace_tensor, workspace_tf_shape,
601982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                              workspace_mkl_shape);
602982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    CHECK_NOTNULL(workspace_tensor);
603982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
604982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  }
605fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane};
606fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
60704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// The operation to compute MaxPool gradients.
60804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// It takes three inputs:
60904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina//   - The original input tensor
61004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina//   - The original output tensor
61104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina//   - Backprop tensor for output
61204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// It produces one output: backprop tensor for input.
61304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainatemplate <class Device, class T>
61404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainaclass MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
61504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina public:
61604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina  explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
617982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      : MklPoolingBackwardOpBase<T>(context) {}
61804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
61904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina  void Compute(OpKernelContext* context) override {
62004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina    try {
621982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      auto cpu_engine = engine(engine::cpu, 0);
622982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      const Tensor& orig_input_tensor =
623982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          MklGetInput(context, kInputTensorIndexOrigInput);
624982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      const Tensor& orig_output_tensor =
625982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          MklGetInput(context, kInputTensorIndexOrigOutput);
626982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      const Tensor& grad_tensor =
627982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          MklGetInput(context, kInputTensorIndexGradient);
628982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      const Tensor& workspace_tensor =
629982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          MklGetInput(context, kInputTensorIndexWorkspace);
630982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      MklDnnShape orig_input_mkl_shape, orig_output_mkl_shape, grad_mkl_shape,
631982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          workspace_mkl_shape;
632982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape);
633982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      GetMklShape(context, kInputTensorIndexOrigOutput, &orig_output_mkl_shape);
634982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape);
635982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      GetMklShape(context, kInputTensorIndexWorkspace, &workspace_mkl_shape);
636982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
637982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      SanityCheckInputs(context, orig_input_tensor, orig_output_tensor,
638982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        grad_tensor, workspace_tensor, orig_input_mkl_shape,
639982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        orig_output_mkl_shape, grad_mkl_shape,
640982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        workspace_mkl_shape);
641982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      if (!context->status().ok()) return;
642982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
643982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      MklDnnData<T> grad_dnn_data(&cpu_engine);
644982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
645982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      MklDnnData<T> output_dnn_data(&cpu_engine);
646982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      Tensor* output_tensor = nullptr;
647982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      MklPoolParameters pool_params;
648982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      TensorShape orig_input_shape;
649982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      memory::dims output_dims_mkl_order, orig_input_dims_mkl_order;
650982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      memory::desc original_input_md = ConfigureOriginalInput(
651982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          context, orig_input_tensor, orig_input_mkl_shape,
652982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          &orig_input_dims_mkl_order, &pool_params, &orig_input_shape);
653982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
654982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      memory::desc original_output_md = this->ConfigureOriginalOutput(
655982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          pool_params, orig_output_mkl_shape, output_dims_mkl_order);
656982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
657982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      memory::desc target_diff_dst_md = this->ConfigureInputGradient(
658982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          grad_mkl_shape, grad_tensor, &grad_dnn_data, original_output_md);
659982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
660982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      output_dnn_data.SetUsrMem(original_input_md);
661982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
662982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      // Create the forward pooling primitive descriptor so we can
663982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      // pass it as a hint to the backward pooling primitive descriptor
664982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      auto pool_fwd_desc = pooling_forward::desc(
665982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          prop_kind::forward, algorithm::pooling_max, original_input_md,
666982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          original_output_md,
667982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({pool_params.row_stride, pool_params.col_stride}),
668982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({pool_params.window_rows, pool_params.window_cols}),
669982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({static_cast<int>(pool_params.pad_top),
670982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        static_cast<int>(pool_params.pad_left)}),
671982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({static_cast<int>(pool_params.pad_bottom),
672982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        static_cast<int>(pool_params.pad_right)}),
673982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          TFPaddingToMklDnnPadding(this->padding_));
674982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      auto pool_fwd_prim_desc =
675982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);
676982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
677982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      auto pool_bkwd_desc = pooling_backward::desc(
678982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          algorithm::pooling_max, output_dnn_data.GetUsrMemDesc(),
679982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          target_diff_dst_md,
680982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({pool_params.row_stride, pool_params.col_stride}),
681982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({pool_params.window_rows, pool_params.window_cols}),
682982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({static_cast<int>(pool_params.pad_top),
683982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        static_cast<int>(pool_params.pad_left)}),
684982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::dims({static_cast<int>(pool_params.pad_bottom),
685982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        static_cast<int>(pool_params.pad_right)}),
686982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          TFPaddingToMklDnnPadding(this->padding_));
687982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
688982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);
689982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
690982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      this->AllocateOutputTensor(context, pool_bkwd_prim_desc,
691982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                 orig_input_dims_mkl_order,
692982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                 this->data_format_mkldnn_, &output_tensor);
693982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      output_dnn_data.SetUsrMemDataHandle(output_tensor);
694982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
695982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      ConfigureWorkspace(workspace_tensor,
696982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         pool_fwd_prim_desc.workspace_primitive_desc(),
697982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         &workspace_dnn_data);
698982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      this->PrepareAndExecuteNet(
699982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          pool_bkwd_prim_desc, &grad_dnn_data, &output_dnn_data,
700982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          memory::primitive_desc(target_diff_dst_md, cpu_engine),
701982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          &workspace_dnn_data);
702982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    } catch (mkldnn::error& e) {
703982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      string error_msg = "Status: " + std::to_string(e.status) +
704982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         ", message: " + string(e.message) + ", in file " +
705982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         string(__FILE__) + ":" + std::to_string(__LINE__);
706982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
707982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                              error_msg));
70804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina    }
70904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina  }  // Compute
71004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
71104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina private:
712982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  // .Input("orig_input: T")
713982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  // .Input("orig_output: T")
714982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  // .Input("grad: T")
715982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  // .Input("workspace: T")
716982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  const int kInputTensorIndexOrigInput = 0;
717982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  const int kInputTensorIndexOrigOutput = 1;
718982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  const int kInputTensorIndexGradient = 2;
719982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  const int kInputTensorIndexWorkspace = 3;
720982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  //  Output("output: T") in Base Class
721982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen
722982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  memory::desc ConfigureOriginalInput(
723982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OpKernelContext* context, const Tensor& tensor_original_input,
724982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      const MklDnnShape& original_input_mkl_shape,
725982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      memory::dims* original_input_dims_mkl_order,
726982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
727982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    *input_tensor_shape = tensor_original_input.shape();
728982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
729982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        context, tensor_original_input, original_input_mkl_shape,
730982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen        original_input_dims_mkl_order, pool_params, *input_tensor_shape);
731982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  }
73204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
733982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  void ConfigureWorkspace(const Tensor& workspace_tensor,
734982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                          memory::primitive_desc workspace_pd,
735982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                          MklDnnData<uint8>* workspace_dnn_data) {
736982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    CHECK_NOTNULL(workspace_dnn_data);
73704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
738982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
739982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  }
74004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
741982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  void SanityCheckInputs(OpKernelContext* context,
742982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         const Tensor& orig_input_tensor,
743982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         const Tensor& orig_output_tensor,
744982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         const Tensor& grad_tensor,
745982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         const Tensor& workspace_tensor,
746982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         const MklDnnShape& orig_input_mkl_shape,
747982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         const MklDnnShape& orig_output_mkl_shape,
748982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         const MklDnnShape& grad_mkl_shape,
749982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                         const MklDnnShape& workspace_mkl_shape) {
750982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    if (!orig_input_mkl_shape.IsMklTensor()) {
751982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(context, orig_input_tensor.dims() == 4,
752982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  errors::InvalidArgument("Original input shape must be "
753982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                          "4-dimensional"));
754982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    } else {
755982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(context, orig_input_mkl_shape.GetDimension() == 4,
756982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  errors::InvalidArgument("Original input shape must be "
757982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                          "4-dimensional"));
758982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    }
759982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    if (!orig_output_mkl_shape.IsMklTensor()) {
760982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(context, orig_output_tensor.dims() == 4,
761982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  errors::InvalidArgument("Original output must be "
762982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                          "4-dimensional"));
763982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    } else {
764982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(context, orig_output_mkl_shape.GetDimension() == 4,
765982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  errors::InvalidArgument("Original output must be "
766982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                          "4-dimensional"));
767982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    }
768982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    if (!grad_mkl_shape.IsMklTensor()) {
769982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(context, grad_tensor.dims() == 4,
770982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  errors::InvalidArgument("Gradient must be 4-dimensional"));
771982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    } else {
772982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(context, grad_mkl_shape.GetDimension() == 4,
773982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  errors::InvalidArgument("Gradient must be "
774982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                          "4-dimensional"));
775982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    }
776982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    if (this->workspace_enabled_) {
777982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      // The workspace should not be an MKL tensor
778982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false,
779982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  errors::InvalidArgument("Workspace tensor should not"
780982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                          " be an MKL Tensor."));
781982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      // It should only have one dimension
782982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(context, workspace_tensor.dims() == 1,
783982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  errors::InvalidArgument("Workspace tensor must be "
784982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                                          "1-dimensional"));
785982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    } else {
786982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(
787982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          context, this->workspace_enabled_,
788982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          errors::Unimplemented("MKL-DNN Max Pooling does not "
78904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina                                "yet support the use case "
79004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina                                "where MaxPoolGrad is called without first"
79104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina                                " calling MaxPool."));
79204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina    }
793982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen  }
79404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina};  // MklMaxPoolingGradOp
79504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
796e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh#endif  // INTEL_MKL_ML
79704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina
798326942394e69074d50d5889218a24c9371eff259Shanqing CaiREGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
799c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                            .Device(DEVICE_CPU)
800c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                            .TypeConstraint<float>("T")
80167f9925ef9ceed02892c200a3122092ab497943aNiranjan Hasabnis                            .Label(mkl_op_registry::kMklOpLabel),
802c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                        MklMaxPoolingOp<CPUDevice, float>);
803fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
804326942394e69074d50d5889218a24c9371eff259Shanqing CaiREGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad")
805c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                            .Device(DEVICE_CPU)
806c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                            .TypeConstraint<float>("T")
80767f9925ef9ceed02892c200a3122092ab497943aNiranjan Hasabnis                            .Label(mkl_op_registry::kMklOpLabel),
808c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba                        MklMaxPoolingGradOp<CPUDevice, float>);
809fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane
810c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba}  // namespace tensorflow
811fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#endif  // INTEL_MKL
812