1ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
3ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   Licensed under the Apache License, Version 2.0 (the "License");
4ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   you may not use this file except in compliance with the License.
5ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   You may obtain a copy of the License at
6ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
7ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   http://www.apache.org/licenses/LICENSE-2.0
8ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
9ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   Unless required by applicable law or agreed to in writing, software
10ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   distributed under the License is distributed on an "AS IS" BASIS,
11ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   See the License for the specific language governing permissions and
13ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   limitations under the License.
14ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower   ==============================================================================*/
15ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
16ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#ifdef INTEL_MKL
17ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#define EIGEN_USE_THREADS
18ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
19ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/common_runtime/device.h"
20ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/framework/common_shape_fns.h"
21ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/framework/numeric_op.h"
22ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/framework/register_types.h"
23ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/util/mkl_util.h"
24ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
25ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
26ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
27d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case#ifndef INTEL_MKL_ML
2890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?#include "mkldnn.hpp"
290f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::algorithm;
300f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::engine;
3190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?using mkldnn::error;
320f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::memory;
3390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?using mkldnn::padding_kind;
340f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::pooling_backward;
350f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::pooling_forward;
3690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?using mkldnn::prop_kind;
3790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?#endif
3890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
39ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowernamespace tensorflow {
40ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
41ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowertypedef Eigen::ThreadPoolDevice CPUDevice;
42ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
43d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case#ifdef INTEL_MKL_ML
4490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
45ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowertemplate <typename Device, typename T>
46326942394e69074d50d5889218a24c9371eff259Shanqing Caiclass MklAvgPoolingOp : public OpKernel {
47ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower public:
48326942394e69074d50d5889218a24c9371eff259Shanqing Cai  explicit MklAvgPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
49ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    string data_format;
50ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
51ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
52ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                errors::InvalidArgument("Invalid data format"));
53ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
54ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
55ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES(context, ksize_.size() == 4,
56ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                errors::InvalidArgument("Sliding window ksize field must "
57ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                        "specify 4 dimensions"));
58ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
59ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES(context, stride_.size() == 4,
60ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                errors::InvalidArgument("Sliding window stride field must "
61ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                        "specify 4 dimensions"));
62ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
63ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
64ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                errors::Unimplemented("Pooling is not yet supported on the "
65ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                      "batch dimension."));
66ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  }
67ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
68ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  void Compute(OpKernelContext* context) override {
69ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklAvgPoolingOpContext mkl_context;
70ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    const Tensor& tensor_in = MklGetInput(context, 0);
71ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    GetMklShape(context, 0, &mkl_context.input_shape);
72ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
73ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
74ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    if (!input_in_mkl_format)
75ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      mkl_context.params.in_dim = tensor_in.dims();
76ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    else
77ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      mkl_context.params.in_dim = mkl_context.input_shape.GetDimension();
78ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
79ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklPoolParameters pool_params;
80ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    if (!input_in_mkl_format) {
81ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      pool_params.Init(context, ksize_, stride_, padding_, data_format_,
82ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                       tensor_in.shape());
83ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    } else {
84ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      pool_params.Init(context, ksize_, stride_, padding_, data_format_,
85ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                       &mkl_context.input_shape);
86ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    }
87ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
88ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    // Extract the parameters for the op from the pooling specs
89ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
90ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
91ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    Tensor mkl_tmp_input_buf_tensor_;
92ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_context.MklCreateLayoutsAndPrimitives(context,
93ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                              &mkl_tmp_input_buf_tensor_);
94326942394e69074d50d5889218a24c9371eff259Shanqing Cai    OP_REQUIRES_OK(context, context->status());
95ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
96ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    Tensor workspace_tensor;
97ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    void* workspace_buf;
98ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    AllocTmpBuffer(context, &workspace_tensor, mkl_context.lt_workspace,
99ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   &workspace_buf);
100ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
101ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    if (mkl_context.convert_input != nullptr) {
102ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      if (input_in_mkl_format == false) {
103ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        CHECK_EQ(
104ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower            dnnConversionExecute_F32(
105ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                mkl_context.convert_input,
106ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data())),
107ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                mkl_context.input_buf),
108ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower            E_SUCCESS);
109ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        CHECK_EQ(dnnDelete_F32(mkl_context.convert_input), E_SUCCESS);
110ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      } else {
111ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        mkl_context.input_shape.GetConvertedFlatData(
112ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower            mkl_context.lt_prim_input,
113ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower            static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data())),
114ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower            mkl_context.input_buf);
115ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      }
116ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      mkl_context.pooling_res[dnnResourceSrc] = mkl_context.input_buf;
117ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    } else {
118ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      mkl_context.pooling_res[dnnResourceSrc] =
119ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower          static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data()));
120ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    }
121ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
122ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    // Declare output tensor and allocate memory
123ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    Tensor* output = nullptr;
124ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    TensorShape tensor_out_shape;
125ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklShape mkl_out_shape;
126ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_out_shape.SetMklTensor(true);
127ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
128ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
129ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                              mkl_context.params.out_sizes,
130ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                              mkl_context.params.out_strides);
131ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
132ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
133ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
134ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                mkl_out_shape.GetMklLayout())) /
135ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                            sizeof(T));
136ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
137326942394e69074d50d5889218a24c9371eff259Shanqing Cai    AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
138ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                              mkl_out_shape);
139ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_context.pooling_res[dnnResourceDst] =
140ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        static_cast<void*>(output->flat<T>().data());
141ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
142ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf;
143ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
144ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    CHECK_EQ(
145ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res),
146ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        E_SUCCESS);
147ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
148ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_context.MklCleanup();
14990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  }  // Compute
150ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
151ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower private:
152ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  typedef struct {
153ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklPoolingOpParams params;
154ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklShape input_shape;
155326942394e69074d50d5889218a24c9371eff259Shanqing Cai    dnnPrimitive_t prim_pooling_fwd = nullptr, convert_input = nullptr;
156326942394e69074d50d5889218a24c9371eff259Shanqing Cai    dnnLayout_t lt_user_input = nullptr, lt_prim_input = nullptr,
157326942394e69074d50d5889218a24c9371eff259Shanqing Cai                lt_workspace = nullptr;
158326942394e69074d50d5889218a24c9371eff259Shanqing Cai    void* input_buf = nullptr;
159ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    void* pooling_res[dnnResourceNumber];
160ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
161ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    void MklCreateLayoutsAndPrimitives(OpKernelContext* context,
162ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                       Tensor* mkl_tmp_input_buf_tensor) {
163ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      bool input_in_mkl_format = input_shape.IsMklTensor();
164ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
165ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      if (!input_in_mkl_format) {
166ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
167ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                     params.in_sizes, params.in_strides),
168ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                 E_SUCCESS);
169ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      } else {
170ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        lt_user_input = (dnnLayout_t)input_shape.GetCurLayout();
171ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      }
172ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
173ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg;
174ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      dnnPrimitiveAttributes_t primAttr = nullptr;
175ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
176ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      // Create DNN primitives
177ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnPoolingCreateForward_F32(
178ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   &prim_pooling_fwd, primAttr, algorithm, lt_user_input,
179ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   params.kernel_size, params.kernel_stride, params.in_offset,
180ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   dnnBorderZerosAsymm),
181ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower               E_SUCCESS);
182ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
183ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
184ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   &lt_prim_input, prim_pooling_fwd, dnnResourceSrc),
185ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower               E_SUCCESS);
186ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      if (!dnnLayoutCompare_F32(lt_user_input, lt_prim_input)) {
187ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_user_input,
188ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                         lt_prim_input),
189ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                 E_SUCCESS);
190ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
191ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_prim_input,
192ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                       &input_buf);
193ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      }
194ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
195ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_fwd,
196ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                                dnnResourceWorkspace),
197ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower               E_SUCCESS);
198ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    }
199ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
200ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    void MklCleanup() {
201ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      bool input_in_mkl_format = input_shape.IsMklTensor();
202ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      if (!input_in_mkl_format) {
203ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
204ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      }
205ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
206ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
207ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnLayoutDelete_F32(lt_prim_input), E_SUCCESS);
208ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    }
209ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  } MklAvgPoolingOpContext;
210ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
211ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  std::vector<int32> ksize_;
212ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  std::vector<int32> stride_;
213ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  Padding padding_;
214ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  TensorFormat data_format_;
215ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower};
216ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
217ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower//-----------------------------------------------------------------------------
218ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
219ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowertemplate <class Device, class T>
220ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowerclass MklAvgPoolingGradOp : public OpKernel {
221ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower public:
222ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  explicit MklAvgPoolingGradOp(OpKernelConstruction* context)
223ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      : OpKernel(context) {
224ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    string data_format;
225ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
226ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
227ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
228ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                errors::InvalidArgument("Invalid data format"));
229ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
230ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES(context, ksize_.size() == 4,
231ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                errors::InvalidArgument("Sliding window ksize field must "
232ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                        "specify 4 dimensions"));
233ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
234ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES(context, stride_.size() == 4,
235ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                errors::InvalidArgument("Sliding window strides field must "
236ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                        "specify 4 dimensions"));
237ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
238ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
239ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                errors::Unimplemented("Pooling is not yet supported on the "
240ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                      "batch dimension."));
241ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  }
242ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
243ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  void Compute(OpKernelContext* context) override {
244ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklAvgPoolingGradOpContext mkl_context;
245ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    const Tensor& tensor_in_shape = MklGetInput(context, 0);
246ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    const Tensor& out_backprop = MklGetInput(context, 1);
247ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    GetMklShape(context, 1, &mkl_context.out_backprop_shape);
248ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    bool outbackprop_in_mkl_format =
249ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        mkl_context.out_backprop_shape.IsMklTensor();
250ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
251ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    TensorShape output_shape;
252ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    auto shape_vec = tensor_in_shape.vec<int32>();
253ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
254ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      output_shape.AddDim(shape_vec(i));
255ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    }
256ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
257ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklPoolParameters pool_params;
258ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    pool_params.Init(context, ksize_, stride_, padding_, data_format_,
259ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                     output_shape);
260ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
261326942394e69074d50d5889218a24c9371eff259Shanqing Cai    if (outbackprop_in_mkl_format == false)
262326942394e69074d50d5889218a24c9371eff259Shanqing Cai      mkl_context.params.in_dim = out_backprop.dims();
263326942394e69074d50d5889218a24c9371eff259Shanqing Cai    else
264326942394e69074d50d5889218a24c9371eff259Shanqing Cai      mkl_context.params.in_dim = mkl_context.out_backprop_shape.GetDimension();
265326942394e69074d50d5889218a24c9371eff259Shanqing Cai
266ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    // Extract the parameters for the op from the pooling specs
267ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
268ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
269ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    // Tensors needed to create temporary buffers
270ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    Tensor outbackprop_buf_tensor;
271ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    void* outbackprop_buf;
272ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_context.MklCreateLayoutsAndPrimitives(context);
273326942394e69074d50d5889218a24c9371eff259Shanqing Cai    OP_REQUIRES_OK(context, context->status());
274ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
275ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    // Check if outbackprop layout requires conversion.
276ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    if (!dnnLayoutCompare_F32(mkl_context.lt_user_outbackprop,
277ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                              mkl_context.lt_prim_outbackprop)) {
278ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnConversionCreate_F32(&mkl_context.convert_outbackprop,
279ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                       mkl_context.lt_user_outbackprop,
280ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                       mkl_context.lt_prim_outbackprop),
281ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower               E_SUCCESS);
282ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
283ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      AllocTmpBuffer(context, &outbackprop_buf_tensor,
284ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                     mkl_context.lt_prim_outbackprop, &outbackprop_buf);
285ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
286ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      if (!outbackprop_in_mkl_format) {
287ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        CHECK_EQ(dnnConversionExecute_F32(mkl_context.convert_outbackprop,
288ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                          static_cast<void*>(const_cast<T*>(
289ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                              out_backprop.flat<T>().data())),
290ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                          outbackprop_buf),
291ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                 E_SUCCESS);
292ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        CHECK_EQ(dnnDelete_F32(mkl_context.convert_outbackprop), E_SUCCESS);
293ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      } else {
294ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        mkl_context.out_backprop_shape.GetConvertedFlatData(
295ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower            mkl_context.lt_prim_outbackprop,
296ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower            static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data())),
297ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower            outbackprop_buf);
298ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      }
299ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      mkl_context.pooling_res[dnnResourceDiffDst] = outbackprop_buf;
300ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    } else {
301ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      mkl_context.pooling_res[dnnResourceDiffDst] =
302ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower          static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data()));
303ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    }
304ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
305ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    // Handle workspace requirements.
306ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    Tensor workspace_buf_tensor;
307ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    void* workspace_buf;
308ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    AllocTmpBuffer(context, &workspace_buf_tensor, mkl_context.lt_workspace,
309ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   &workspace_buf);
310ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf;
311ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
312ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    // Handle MKL output tensor setup.
313ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    Tensor* output = nullptr;
314ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    TensorShape tensor_out_shape;
315ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklShape mkl_out_shape;
316ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_out_shape.SetMklTensor(true);
317ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_bwd,
318ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                               dnnResourceDiffSrc);
319ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
320ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                              mkl_context.params.in_sizes,
321ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                              mkl_context.params.in_strides);
322ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
323ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
324ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
325ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                mkl_out_shape.GetMklLayout())) /
326ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                            sizeof(T));
327ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
328326942394e69074d50d5889218a24c9371eff259Shanqing Cai    AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
329ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                              mkl_out_shape);
330ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
331ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    // Set output tensor.
332ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_context.pooling_res[dnnResourceDiffSrc] =
333ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        static_cast<void*>(output->flat<T>().data());
334ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
335ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    // Execute primitive.
336ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    CHECK_EQ(
337ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
338ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        E_SUCCESS);
339ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
340ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    mkl_context.MklCleanup();
341ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  }
342ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
343ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower private:
344ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  typedef struct {
345ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklPoolingOpParams params;
346ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    MklShape out_backprop_shape;
347326942394e69074d50d5889218a24c9371eff259Shanqing Cai    dnnPrimitive_t prim_pooling_bwd = nullptr, convert_outbackprop = nullptr;
348ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    void* pooling_res[dnnResourceNumber];
349326942394e69074d50d5889218a24c9371eff259Shanqing Cai    dnnLayout_t lt_user_input = nullptr, lt_user_outbackprop = nullptr,
350326942394e69074d50d5889218a24c9371eff259Shanqing Cai                lt_prim_outbackprop = nullptr, lt_workspace = nullptr;
351ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
352ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
353ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      const Tensor& tensor_in_shape = MklGetInput(context, 0);
354ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      const Tensor& out_backprop = MklGetInput(context, 1);
355ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor();
356ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
357ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      if (!outbackprop_in_mkl_format) {
358ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        // For avgpooling, tensor_in_shape should have 1 dimension, and 4
359ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        // elements.
3600f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower        OP_REQUIRES(
3610f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower            context,
3620f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower            tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
3630f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower            errors::InvalidArgument("original input shape must be "
3640f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                                    "1-dimensional and 4 elements"));
365ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
366ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        // For avgpooling, out_backprop should have 4 dimensions.
367ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        OP_REQUIRES(context, out_backprop.dims() == 4,
368ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                    errors::InvalidArgument("out_backprop must be "
369ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                            "4-dimensional"));
370ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      } else {
371ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        // Input in MKL format.
372ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        // For avgpooling, out_backprop should have 4 dimensions.
373ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        OP_REQUIRES(context, out_backprop_shape.GetDimension() == 4,
374ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                    errors::InvalidArgument("out_backprop must be "
375ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                            "4-dimensional"));
376ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      }
377ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
378ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      // TODO(inteltf): Get outbackprop layout.
379ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      // Do we need to create layout in every invocation?
380ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      if (!outbackprop_in_mkl_format) {
381ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        CHECK_EQ(dnnLayoutCreate_F32(&lt_user_outbackprop, params.in_dim,
382ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                     params.out_sizes, params.out_strides),
383ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                 E_SUCCESS);
384ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      } else {
385ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        lt_user_outbackprop = (dnnLayout_t)out_backprop_shape.GetCurLayout();
386ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      }
387ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
388ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      // Create the backward primitive
389ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      // Create DNN user layout
390ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
391ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                   params.in_sizes, params.in_strides),
392ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower               E_SUCCESS);
393ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
394ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      // Create PoolingBackward primitive
395ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg;
396ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      dnnPrimitiveAttributes_t primAttr = nullptr;
397ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnPoolingCreateBackward_F32(
398ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   &prim_pooling_bwd, primAttr, algorithm, lt_user_input,
399ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   params.kernel_size, params.kernel_stride, params.in_offset,
400ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   dnnBorderZerosAsymm),
401ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower               E_SUCCESS);
402ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
403ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      // Create expected outbackprop layout from the primitive.
404ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
405ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                   &lt_prim_outbackprop, prim_pooling_bwd, dnnResourceDiffDst),
406ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower               E_SUCCESS);
407ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
408ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_bwd,
409ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower                                                dnnResourceWorkspace),
410ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower               E_SUCCESS);
411ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    }
412ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
413ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    void MklCleanup() {
414ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor();
415ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS);
416ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
417ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      if (!outbackprop_in_mkl_format) {
418ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower        CHECK_EQ(dnnLayoutDelete_F32(lt_user_outbackprop), E_SUCCESS);
419ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      }
420ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnLayoutDelete_F32(lt_prim_outbackprop), E_SUCCESS);
421ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower      CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS);
422ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower    }
423ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  } MklAvgPoolingGradOpContext;
424ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
425ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  std::vector<int32> ksize_;
426ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  std::vector<int32> stride_;
427ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  Padding padding_;
428ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower  TensorFormat data_format_;
42990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?};  // MklAvgPoolingGradOp
43090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
431d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case#else
43290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
43390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?template <typename Device, typename T>
43490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
43590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? public:
43690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  explicit MklAvgPoolingOp(OpKernelConstruction* context)
4370f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      : MklPoolingForwardOpBase<T>(context) {
43890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    // Workspace is an MKLDNN construct that is only used in Max Pooling.
43990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    // So set workspace_enabled_ to false.
44090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    this->workspace_enabled_ = false;
44190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  }
44290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
44390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  void Compute(OpKernelContext* context) override {
44490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    try {
44590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      auto cpu_engine = engine(engine::cpu, 0);
4460f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      const Tensor& input_tensor =
4470f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          MklGetInput(context, this->kInputTensorIndexInput);
44890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      MklDnnShape dnn_shape_input;
44990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
45090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      this->SanityCheckInput(context, input_tensor, dnn_shape_input);
45190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      if (!context->status().ok()) return;
45290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
45390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      MklDnnData<T> dnn_data_input(&cpu_engine);
45490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      MklDnnData<T> dnn_data_output(&cpu_engine);
45590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
45690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // initialize variables for the pooling op
45790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      MklPoolParameters pool_params;
45890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // Get the input tensor and initialize the pooling parameters
4590f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
4600f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                           &dnn_data_input);
46190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      OP_REQUIRES_OK(context, context->status());
46290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
46390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // Declare output tensor
46490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      Tensor* output_tensor = nullptr;
46590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      memory::dims output_dims_mkl_order;
46690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      this->GetOutputDims(pool_params, &output_dims_mkl_order);
46790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
468d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case      // If input is an empty tensor, allocate an empty output tensor and return
469d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case      if (input_tensor.NumElements() == 0) {
470d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        MklDnnShape output_mkl_shape;
471d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        output_mkl_shape.SetMklTensor(false);
472d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        TensorShape output_tf_shape;
473d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        if (pool_params.data_format == TensorFormat::FORMAT_NCHW) {
474d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case          output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
475d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        } else {
476d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case          memory::dims output_dims_NHWC_order;
477d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case          output_dims_NHWC_order = {pool_params.tensor_in_batch,
478d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case                                    static_cast<int>(pool_params.out_height),
479d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case                                    static_cast<int>(pool_params.out_width),
480d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case                                    pool_params.out_depth};
481d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case          output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
482d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        }
483d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        const int kOutputIndex = 0;
484d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor,
485d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case                                  output_tf_shape, output_mkl_shape);
486d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        CHECK_NOTNULL(output_tensor);
487d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case        return;
488d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case      }
489d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case
49090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // If input is in Mkl layout, then just get the memory format from it
49190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // directly, instead of using input data_format to AvgPool.
49290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      if (dnn_shape_input.IsMklTensor()) {
4930f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower        dnn_data_output.SetUsrMem(
4940f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower            output_dims_mkl_order,
4950f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower            static_cast<memory::format>(
4960f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                dnn_data_input.GetUsrMemDesc().data.format));
49790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
49890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      } else {
4990f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower        dnn_data_output.SetUsrMem(output_dims_mkl_order,
5000f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                                  this->data_format_mkldnn_);
50190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      }
50290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
5030f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      // describe the memory layout
50490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
50590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
50690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // 3. create a pooling primitive descriptor
5070f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      auto pool_desc = pooling_forward::desc(
5080f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          prop_kind::forward, algorithm::pooling_avg_exclude_padding,
5090f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
5100f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({pool_params.row_stride, pool_params.col_stride}),
5110f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({pool_params.window_rows, pool_params.window_cols}),
5120f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({static_cast<int>(pool_params.pad_top),
5130f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                        static_cast<int>(pool_params.pad_left)}),
5140f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({static_cast<int>(pool_params.pad_bottom),
5150f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                        static_cast<int>(pool_params.pad_right)}),
5160f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          TFPaddingToMklDnnPadding(this->padding_));
5170f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      auto pool_prim_desc =
5180f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          pooling_forward::primitive_desc(pool_desc, cpu_engine);
51990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
52090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order,
5210f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                                 this->data_format_mkldnn_, &output_tensor);
52290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      CHECK_NOTNULL(output_tensor);
52390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
52490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      OP_REQUIRES_OK(context, context->status());
52590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      dnn_data_output.SetUsrMemDataHandle(output_tensor);
52690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
5270f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input,
5280f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                                 &dnn_data_output);
5290f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower    } catch (mkldnn::error& e) {
5300f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      string error_msg = "Status: " + std::to_string(e.status) +
5310f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                         ", message: " + string(e.message) + ", in file " +
5320f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                         string(__FILE__) + ":" + std::to_string(__LINE__);
5330f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      OP_REQUIRES_OK(
5340f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          context,
5350f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          errors::Aborted("Operation received an exception:", error_msg));
53690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    }
53790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  }  // Compute
5380f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower};   // MklAvgPoolingOp
53990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
54090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?//-----------------------------------------------------------------------------
54190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
54290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?template <class Device, class T>
54390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
54490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? public:
54590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  explicit MklAvgPoolingGradOp(OpKernelConstruction* context)
5460f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      : MklPoolingBackwardOpBase<T>(context) {}
54790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
54890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  void Compute(OpKernelContext* context) override {
54990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    try {
55090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      auto cpu_engine = engine(engine::cpu, 0);
55190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape;
5520f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      const Tensor& tensor_in_shape =
5530f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          MklGetInput(context, kInputTensorIndexInputShape);
5540f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      const Tensor& input_gradient_tensor =
5550f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          MklGetInput(context, kInputTensorIndexInputGradient);
55690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      GetMklShape(context, kInputTensorIndexInputShape,
5570f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                  &original_input_mkl_shape);
55890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      GetMklShape(context, kInputTensorIndexInputGradient,
5590f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                  &input_gradient_mkl_shape);
56090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
5610f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor,
5620f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                        original_input_mkl_shape, input_gradient_mkl_shape);
56390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      if (!context->status().ok()) return;
56490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
56590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // Used to allocate output_diff_src/diff_src
56690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // and create pool_fwd mdm desc
56790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // 0. Input("orig_input_shape: int32") //NOT a T Tensor!
56890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // 1. Input("grad: T")
56990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
57090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      MklDnnData<T> input_gradient_diff_dst(&cpu_engine);
57190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      MklDnnData<T> output_diff_src(&cpu_engine);
57290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      Tensor* output_tensor_diff_src = nullptr;
57390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      TensorShape original_input_shape;
57490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      MklPoolParameters pool_params;
57590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      memory::dims output_dims_mkl_order, original_input_dims_nchw;
57690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // Configure the original input memory descriptor
5770f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      memory::desc original_input_md = ConfigureOriginalInput(
5780f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          context, tensor_in_shape, original_input_mkl_shape,
5790f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          &original_input_dims_nchw, &pool_params, &original_input_shape);
58090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
58190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // configure the original output memory descriptor
58290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // by definition, the shape of the original output is the same
58390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // as the shape of the gradient diff_dst
58490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      memory::desc original_output_md = this->ConfigureOriginalOutput(
5850f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          pool_params, input_gradient_mkl_shape, output_dims_mkl_order);
58690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
58790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      memory::desc target_diff_dst_md = this->ConfigureInputGradient(
5880f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          input_gradient_mkl_shape, input_gradient_tensor,
5890f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          &input_gradient_diff_dst, original_output_md);
59090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // The shape of the output diff src needs to be the same shape as the
59190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // original input. But we will set its format to be same as the format of
59290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // input gradient. We won't use format of original input since it will
59390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // always be in Tensorflow layout (given that AvgPoolGrad gets shape of
59490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // the input rather than actual input).
5950f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      output_diff_src.SetUsrMem(
5960f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          original_input_dims_nchw,
5970f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          static_cast<memory::format>(target_diff_dst_md.data.format));
59890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
59990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // Create the forward pooling primitive descriptor so we can reference it
60090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // in the backward pooling primitive descriptor
6010f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      auto pool_fwd_desc = pooling_forward::desc(
6020f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          prop_kind::forward, algorithm::pooling_avg_exclude_padding,
6030f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          original_input_md, original_output_md,
6040f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({pool_params.row_stride, pool_params.col_stride}),
6050f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({pool_params.window_rows, pool_params.window_cols}),
6060f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({static_cast<int>(pool_params.pad_top),
6070f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                        static_cast<int>(pool_params.pad_left)}),
6080f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({static_cast<int>(pool_params.pad_bottom),
6090f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                        static_cast<int>(pool_params.pad_right)}),
6100f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          TFPaddingToMklDnnPadding(this->padding_));
6110f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      auto pool_fwd_prim_desc =
6120f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);
61390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
61490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      auto pool_bkwd_desc = pooling_backward::desc(
6150f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          algorithm::pooling_avg_exclude_padding,
6160f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          output_diff_src.GetUsrMemDesc(), target_diff_dst_md,
6170f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({pool_params.row_stride, pool_params.col_stride}),
6180f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({pool_params.window_rows, pool_params.window_cols}),
6190f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({static_cast<int>(pool_params.pad_top),
6200f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                        static_cast<int>(pool_params.pad_left)}),
6210f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::dims({static_cast<int>(pool_params.pad_bottom),
6220f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                        static_cast<int>(pool_params.pad_right)}),
6230f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          TFPaddingToMklDnnPadding(this->padding_));
6240f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
6250f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);
6260f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      this->AllocateOutputTensor(
6270f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          context, pool_bkwd_prim_desc, original_input_dims_nchw,
6280f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          this->data_format_mkldnn_, &output_tensor_diff_src);
62990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
63090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src);
63190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
6320f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      this->PrepareAndExecuteNet(
6330f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src,
6340f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          memory::primitive_desc(target_diff_dst_md, cpu_engine));
6350f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower    } catch (mkldnn::error& e) {
63690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      string error_msg = "Status: " + std::to_string(e.status) +
6370f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                         ", message: " + string(e.message) + ", in file " +
6380f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                         string(__FILE__) + ":" + std::to_string(__LINE__);
6390f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
6400f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                                              error_msg));
64190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    }
64290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  }  // Compute
64390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
64490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? private:
64590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  // 0. Input("orig_input_shape: int32")
64690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  // 1. Input("grad: T")
64790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  const int kInputTensorIndexInputShape = 0;
64890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  const int kInputTensorIndexInputGradient = 1;
64990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
6500f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower  memory::desc ConfigureOriginalInput(
6510f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      OpKernelContext* context, const Tensor& tensor_original_input_shape,
6520f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      const MklDnnShape& original_input_mkl_shape,
6530f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      memory::dims* original_input_dims_mkl_order,
6540f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
65590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    CHECK_NOTNULL(original_input_dims_mkl_order);
65690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    CHECK_NOTNULL(pool_params);
65790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    CHECK_NOTNULL(input_tensor_shape);
65890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    // For AvgPoolGrad, we only get the size of the original input because
65990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    // The original data is irrelvant.
66090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    auto shape_vec = tensor_original_input_shape.vec<int32>();
66190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    for (int64 i = 0; i < tensor_original_input_shape.NumElements(); ++i) {
66290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      input_tensor_shape->AddDim(shape_vec(i));
66390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    }
66490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
66590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
6660f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower        context, tensor_original_input_shape, original_input_mkl_shape,
6670f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower        original_input_dims_mkl_order, pool_params, *input_tensor_shape);
6680f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower  }
66990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
67090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  void SanityCheckInputs(OpKernelContext* context,
6710f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                         const Tensor& tensor_in_shape,
6720f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                         const Tensor& input_gradient_tensor,
6730f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                         const MklDnnShape& original_input_mkl_shape,
6740f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                         const MklDnnShape& input_gradient_mkl_shape) {
67590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    if (!original_input_mkl_shape.IsMklTensor()) {
6760f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      OP_REQUIRES(
6770f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          context,
6780f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower          tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
67990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?          errors::InvalidArgument("original input shape must be "
6800f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                                  "1-dimensional and 4 elements"));
68190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    } else {
6820f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      OP_REQUIRES(context,
6830f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                  original_input_mkl_shape.GetDimension() == 1 &&
6840f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                      original_input_mkl_shape.DimSize(0) == 4,
6850f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                  errors::InvalidArgument("original input shape must be "
6860f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                                          "1-dimensional and 4 elements"));
68790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    }
68890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
68990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    if (!input_gradient_mkl_shape.IsMklTensor()) {
69090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      // For avgpooling, input_gradient_diff_dst should have 4 dimensions.
69190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      OP_REQUIRES(context, input_gradient_tensor.dims() == 4,
6920f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                  errors::InvalidArgument("Gradient shape must be "
6930f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                                          "4-dimensional"));
69490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    } else {
69590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?      OP_REQUIRES(context, input_gradient_mkl_shape.GetDimension() == 4,
6960f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                  errors::InvalidArgument("Gradient shape must be "
6970f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                                          "4-dimensional"));
69890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?    }
69990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?  }
70090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?};  // MklAvgPoolingGradOp
70190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?
702d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case#endif  // INTEL_MKL_ML
703ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
// Kernel registrations: _MklAvgPool -> MklAvgPoolingOp and
// _MklAvgPoolGrad -> MklAvgPoolingGradOp. Both are CPU-only, registered for
// T=float, and tagged with mkl_op_registry::kMklOpLabel.
// NOTE(review): these sit after the `#endif  // INTEL_MKL_ML` above, so they
// presumably apply to whichever implementation that conditional selected —
// confirm against the top of the file.
REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklAvgPoolingOp<CPUDevice, float>);

REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklAvgPoolingGradOp<CPUDevice, float>);
715ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower
716ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower}  // namespace tensorflow
717ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#endif  // INTEL_MKL
718