1ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 3ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Licensed under the Apache License, Version 2.0 (the "License"); 4ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower you may not use this file except in compliance with the License. 5ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower You may obtain a copy of the License at 6ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 7ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 8ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 9ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Unless required by applicable law or agreed to in writing, software 10ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower distributed under the License is distributed on an "AS IS" BASIS, 11ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower See the License for the specific language governing permissions and 13ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower limitations under the License. 14ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower ==============================================================================*/ 15ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 16ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#ifdef INTEL_MKL 17ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#define EIGEN_USE_THREADS 18ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 19ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/common_runtime/device.h" 20ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/framework/common_shape_fns.h" 21ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/framework/numeric_op.h" 22ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/framework/register_types.h" 23ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/util/mkl_util.h" 24ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 25ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#include "tensorflow/core/kernels/mkl_pooling_ops_common.h" 26ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 27d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case#ifndef INTEL_MKL_ML 2890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?#include "mkldnn.hpp" 290f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::algorithm; 300f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::engine; 3190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?using mkldnn::error; 320f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::memory; 3390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?using mkldnn::padding_kind; 340f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::pooling_backward; 350f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowerusing mkldnn::pooling_forward; 3690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?using mkldnn::prop_kind; 3790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?#endif 3890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 39ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowernamespace tensorflow { 40ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 41ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowertypedef Eigen::ThreadPoolDevice CPUDevice; 42ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 43d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case#ifdef INTEL_MKL_ML 4490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 45ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowertemplate <typename Device, typename T> 46326942394e69074d50d5889218a24c9371eff259Shanqing Caiclass MklAvgPoolingOp : public OpKernel { 47ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower public: 48326942394e69074d50d5889218a24c9371eff259Shanqing Cai explicit MklAvgPoolingOp(OpKernelConstruction* context) : OpKernel(context) { 49ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower string data_format; 50ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 51ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 52ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::InvalidArgument("Invalid data format")); 53ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 54ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); 55ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, ksize_.size() == 4, 56ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::InvalidArgument("Sliding window ksize field must " 57ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower "specify 4 dimensions")); 58ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); 59ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, stride_.size() == 4, 60ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::InvalidArgument("Sliding window stride field must " 61ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower "specify 4 dimensions")); 62ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 63ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, 64ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::Unimplemented("Pooling is not yet supported on the " 65ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower "batch dimension.")); 66ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 67ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 68ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void Compute(OpKernelContext* context) override { 69ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklAvgPoolingOpContext mkl_context; 70ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower const Tensor& tensor_in = MklGetInput(context, 0); 71ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower GetMklShape(context, 0, &mkl_context.input_shape); 72ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); 73ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 74ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!input_in_mkl_format) 75ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.params.in_dim = tensor_in.dims(); 76ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower else 77ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.params.in_dim = mkl_context.input_shape.GetDimension(); 78ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 79ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklPoolParameters pool_params; 80ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!input_in_mkl_format) { 81ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower pool_params.Init(context, ksize_, stride_, padding_, data_format_, 82ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower tensor_in.shape()); 83ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } else { 84ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower pool_params.Init(context, ksize_, stride_, padding_, data_format_, 85ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower &mkl_context.input_shape); 86ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 87ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 88ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Extract the parameters for the op from the pooling specs 89ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); 90ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 91ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Tensor mkl_tmp_input_buf_tensor_; 92ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.MklCreateLayoutsAndPrimitives(context, 93ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower &mkl_tmp_input_buf_tensor_); 94326942394e69074d50d5889218a24c9371eff259Shanqing Cai OP_REQUIRES_OK(context, context->status()); 95ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 96ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Tensor workspace_tensor; 97ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void* workspace_buf; 98ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower AllocTmpBuffer(context, &workspace_tensor, mkl_context.lt_workspace, 99ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower &workspace_buf); 100ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 101ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (mkl_context.convert_input != nullptr) { 102ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (input_in_mkl_format == false) { 103ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ( 104ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnConversionExecute_F32( 105ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.convert_input, 106ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data())), 107ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.input_buf), 108ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 109ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnDelete_F32(mkl_context.convert_input), E_SUCCESS); 110ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } else { 111ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.input_shape.GetConvertedFlatData( 112ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.lt_prim_input, 113ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data())), 114ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.input_buf); 115ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 116ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.pooling_res[dnnResourceSrc] = mkl_context.input_buf; 117ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } else { 118ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.pooling_res[dnnResourceSrc] = 119ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data())); 120ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 121ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 122ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Declare output tensor and allocate memory 123ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Tensor* output = nullptr; 124ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower TensorShape tensor_out_shape; 125ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklShape mkl_out_shape; 126ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.SetMklTensor(true); 127ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst); 128ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.SetTfLayout(mkl_context.params.in_dim, 129ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.params.out_sizes, 130ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.params.out_strides); 131ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); 132ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 133ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>( 134ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.GetMklLayout())) / 135ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower sizeof(T)); 136ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 137326942394e69074d50d5889218a24c9371eff259Shanqing Cai AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape, 138ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape); 139ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.pooling_res[dnnResourceDst] = 140ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower static_cast<void*>(output->flat<T>().data()); 141ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 142ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf; 143ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 144ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ( 145ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res), 146ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 147ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 148ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.MklCleanup(); 14990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } // Compute 150ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 151ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower private: 152ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower typedef struct { 153ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklPoolingOpParams params; 154ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklShape input_shape; 155326942394e69074d50d5889218a24c9371eff259Shanqing Cai dnnPrimitive_t prim_pooling_fwd = nullptr, convert_input = nullptr; 156326942394e69074d50d5889218a24c9371eff259Shanqing Cai dnnLayout_t lt_user_input = nullptr, lt_prim_input = nullptr, 157326942394e69074d50d5889218a24c9371eff259Shanqing Cai lt_workspace = nullptr; 158326942394e69074d50d5889218a24c9371eff259Shanqing Cai void* input_buf = nullptr; 159ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void* pooling_res[dnnResourceNumber]; 160ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 161ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void MklCreateLayoutsAndPrimitives(OpKernelContext* context, 162ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Tensor* mkl_tmp_input_buf_tensor) { 163ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower bool input_in_mkl_format = input_shape.IsMklTensor(); 164ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 165ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!input_in_mkl_format) { 166ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutCreate_F32(<_user_input, params.in_dim, 167ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower params.in_sizes, params.in_strides), 168ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 169ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } else { 170ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower lt_user_input = (dnnLayout_t)input_shape.GetCurLayout(); 171ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 172ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 173ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg; 174ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnPrimitiveAttributes_t primAttr = nullptr; 175ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 176ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Create DNN primitives 177ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnPoolingCreateForward_F32( 178ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower &prim_pooling_fwd, primAttr, algorithm, lt_user_input, 179ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower params.kernel_size, params.kernel_stride, params.in_offset, 180ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnBorderZerosAsymm), 181ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 182ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 183ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 184ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower <_prim_input, prim_pooling_fwd, dnnResourceSrc), 185ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 186ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!dnnLayoutCompare_F32(lt_user_input, lt_prim_input)) { 187ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_user_input, 188ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower lt_prim_input), 189ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 190ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 191ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_prim_input, 192ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower &input_buf); 193ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 194ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 195ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, prim_pooling_fwd, 196ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnResourceWorkspace), 197ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 198ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 199ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 200ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void MklCleanup() { 201ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower bool input_in_mkl_format = input_shape.IsMklTensor(); 202ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!input_in_mkl_format) { 203ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS); 204ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 205ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 206ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); 207ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutDelete_F32(lt_prim_input), E_SUCCESS); 208ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 209ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } MklAvgPoolingOpContext; 210ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 211ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower std::vector<int32> ksize_; 212ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower std::vector<int32> stride_; 213ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Padding padding_; 214ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower TensorFormat data_format_; 215ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower}; 216ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 217ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower//----------------------------------------------------------------------------- 218ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 219ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowertemplate <class Device, class T> 220ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlowerclass MklAvgPoolingGradOp : public OpKernel { 221ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower public: 222ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower explicit MklAvgPoolingGradOp(OpKernelConstruction* context) 223ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower : OpKernel(context) { 224ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower string data_format; 225ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 226ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 227ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 228ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::InvalidArgument("Invalid data format")); 229ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); 230ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, ksize_.size() == 4, 231ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::InvalidArgument("Sliding window ksize field must " 232ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower "specify 4 dimensions")); 233ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); 234ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, stride_.size() == 4, 235ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::InvalidArgument("Sliding window strides field must " 236ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower "specify 4 dimensions")); 237ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 238ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, 239ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::Unimplemented("Pooling is not yet supported on the " 240ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower "batch dimension.")); 241ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 242ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 243ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void Compute(OpKernelContext* context) override { 244ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklAvgPoolingGradOpContext mkl_context; 245ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower const Tensor& tensor_in_shape = MklGetInput(context, 0); 246ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower const Tensor& out_backprop = MklGetInput(context, 1); 247ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower GetMklShape(context, 1, &mkl_context.out_backprop_shape); 248ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower bool outbackprop_in_mkl_format = 249ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.out_backprop_shape.IsMklTensor(); 250ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 251ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower TensorShape output_shape; 252ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower auto shape_vec = tensor_in_shape.vec<int32>(); 253ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) { 254ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower output_shape.AddDim(shape_vec(i)); 255ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 256ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 257ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklPoolParameters pool_params; 258ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower pool_params.Init(context, ksize_, stride_, padding_, data_format_, 259ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower output_shape); 260ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 261326942394e69074d50d5889218a24c9371eff259Shanqing Cai if (outbackprop_in_mkl_format == false) 262326942394e69074d50d5889218a24c9371eff259Shanqing Cai mkl_context.params.in_dim = out_backprop.dims(); 263326942394e69074d50d5889218a24c9371eff259Shanqing Cai else 264326942394e69074d50d5889218a24c9371eff259Shanqing Cai mkl_context.params.in_dim = mkl_context.out_backprop_shape.GetDimension(); 265326942394e69074d50d5889218a24c9371eff259Shanqing Cai 266ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Extract the parameters for the op from the pooling specs 267ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); 268ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 269ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Tensors needed to create temporary buffers 270ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Tensor outbackprop_buf_tensor; 271ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void* outbackprop_buf; 272ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.MklCreateLayoutsAndPrimitives(context); 273326942394e69074d50d5889218a24c9371eff259Shanqing Cai OP_REQUIRES_OK(context, context->status()); 274ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 275ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Check if outbackprop layout requires conversion. 276ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!dnnLayoutCompare_F32(mkl_context.lt_user_outbackprop, 277ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.lt_prim_outbackprop)) { 278ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnConversionCreate_F32(&mkl_context.convert_outbackprop, 279ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.lt_user_outbackprop, 280ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.lt_prim_outbackprop), 281ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 282ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 283ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower AllocTmpBuffer(context, &outbackprop_buf_tensor, 284ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.lt_prim_outbackprop, &outbackprop_buf); 285ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 286ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!outbackprop_in_mkl_format) { 287ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnConversionExecute_F32(mkl_context.convert_outbackprop, 288ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower static_cast<void*>(const_cast<T*>( 289ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower out_backprop.flat<T>().data())), 290ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower outbackprop_buf), 291ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 292ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnDelete_F32(mkl_context.convert_outbackprop), E_SUCCESS); 293ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } else { 294ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.out_backprop_shape.GetConvertedFlatData( 295ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.lt_prim_outbackprop, 296ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data())), 297ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower outbackprop_buf); 298ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 299ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.pooling_res[dnnResourceDiffDst] = outbackprop_buf; 300ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } else { 301ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.pooling_res[dnnResourceDiffDst] = 302ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data())); 303ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 304ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 305ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Handle workspace requirements. 306ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Tensor workspace_buf_tensor; 307ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void* workspace_buf; 308ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower AllocTmpBuffer(context, &workspace_buf_tensor, mkl_context.lt_workspace, 309ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower &workspace_buf); 310ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf; 311ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 312ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Handle MKL output tensor setup. 313ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Tensor* output = nullptr; 314ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower TensorShape tensor_out_shape; 315ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklShape mkl_out_shape; 316ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.SetMklTensor(true); 317ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_bwd, 318ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnResourceDiffSrc); 319ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.SetTfLayout(mkl_context.params.in_dim, 320ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.params.in_sizes, 321ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.params.in_strides); 322ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); 323ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 324ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>( 325ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape.GetMklLayout())) / 326ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower sizeof(T)); 327ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 328326942394e69074d50d5889218a24c9371eff259Shanqing Cai AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape, 329ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_out_shape); 330ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 331ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Set output tensor. 332ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.pooling_res[dnnResourceDiffSrc] = 333ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower static_cast<void*>(output->flat<T>().data()); 334ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 335ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Execute primitive. 336ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ( 337ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res), 338ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 339ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 340ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower mkl_context.MklCleanup(); 341ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 342ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 343ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower private: 344ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower typedef struct { 345ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklPoolingOpParams params; 346ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklShape out_backprop_shape; 347326942394e69074d50d5889218a24c9371eff259Shanqing Cai dnnPrimitive_t prim_pooling_bwd = nullptr, convert_outbackprop = nullptr; 348ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void* pooling_res[dnnResourceNumber]; 349326942394e69074d50d5889218a24c9371eff259Shanqing Cai dnnLayout_t lt_user_input = nullptr, lt_user_outbackprop = nullptr, 350326942394e69074d50d5889218a24c9371eff259Shanqing Cai lt_prim_outbackprop = nullptr, lt_workspace = nullptr; 351ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 352ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void MklCreateLayoutsAndPrimitives(OpKernelContext* context) { 353ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower const Tensor& tensor_in_shape = MklGetInput(context, 0); 354ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower const Tensor& out_backprop = MklGetInput(context, 1); 355ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor(); 356ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 357ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!outbackprop_in_mkl_format) { 358ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // For avgpooling, tensor_in_shape should have 1 dimension, and 4 359ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // elements. 3600f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower OP_REQUIRES( 3610f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower context, 3620f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4, 3630f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower errors::InvalidArgument("original input shape must be " 3640f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower "1-dimensional and 4 elements")); 365ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 366ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // For avgpooling, out_backprop should have 4 dimensions. 367ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, out_backprop.dims() == 4, 368ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::InvalidArgument("out_backprop must be " 369ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower "4-dimensional")); 370ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } else { 371ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Input in MKL format. 372ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // For avgpooling, out_backprop should have 4 dimensions. 373ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower OP_REQUIRES(context, out_backprop_shape.GetDimension() == 4, 374ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower errors::InvalidArgument("out_backprop must be " 375ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower "4-dimensional")); 376ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 377ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 378ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // TODO(inteltf): Get outbackprop layout. 379ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Do we need to create layout in every invocation? 380ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!outbackprop_in_mkl_format) { 381ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutCreate_F32(<_user_outbackprop, params.in_dim, 382ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower params.out_sizes, params.out_strides), 383ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 384ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } else { 385ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower lt_user_outbackprop = (dnnLayout_t)out_backprop_shape.GetCurLayout(); 386ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 387ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 388ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Create the backward primitive 389ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Create DNN user layout 390ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutCreate_F32(<_user_input, params.in_dim, 391ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower params.in_sizes, params.in_strides), 392ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 393ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 394ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Create PoolingBackward primitive 395ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg; 396ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnPrimitiveAttributes_t primAttr = nullptr; 397ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnPoolingCreateBackward_F32( 398ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower &prim_pooling_bwd, primAttr, algorithm, lt_user_input, 399ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower params.kernel_size, params.kernel_stride, params.in_offset, 400ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnBorderZerosAsymm), 401ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 402ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 403ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower // Create expected outbackprop layout from the primitive. 404ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 405ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower <_prim_outbackprop, prim_pooling_bwd, dnnResourceDiffDst), 406ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 407ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 408ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, prim_pooling_bwd, 409ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower dnnResourceWorkspace), 410ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower E_SUCCESS); 411ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 412ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 413ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower void MklCleanup() { 414ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor(); 415ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS); 416ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS); 417ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower if (!outbackprop_in_mkl_format) { 418ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutDelete_F32(lt_user_outbackprop), E_SUCCESS); 419ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 420ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutDelete_F32(lt_prim_outbackprop), E_SUCCESS); 421ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS); 422ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } 423ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower } MklAvgPoolingGradOpContext; 424ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 425ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower std::vector<int32> ksize_; 426ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower std::vector<int32> stride_; 427ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower Padding padding_; 428ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower TensorFormat data_format_; 42990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?}; // MklAvgPoolingGradOp 43090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 431d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case#else 43290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 43390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?template <typename Device, typename T> 43490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> { 43590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? public: 43690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? explicit MklAvgPoolingOp(OpKernelConstruction* context) 4370f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower : MklPoolingForwardOpBase<T>(context) { 43890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // Workspace is an MKLDNN construct that is only used in Max Pooling. 43990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // So set workspace_enabled_ to false. 44090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? this->workspace_enabled_ = false; 44190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } 44290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 44390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? void Compute(OpKernelContext* context) override { 44490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? try { 44590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? auto cpu_engine = engine(engine::cpu, 0); 4460f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower const Tensor& input_tensor = 4470f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower MklGetInput(context, this->kInputTensorIndexInput); 44890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? MklDnnShape dnn_shape_input; 44990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input); 45090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? this->SanityCheckInput(context, input_tensor, dnn_shape_input); 45190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? if (!context->status().ok()) return; 45290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 45390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? MklDnnData<T> dnn_data_input(&cpu_engine); 45490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? MklDnnData<T> dnn_data_output(&cpu_engine); 45590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 45690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // initialize variables for the pooling op 45790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? MklPoolParameters pool_params; 45890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // Get the input tensor and initialize the pooling parameters 4590f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params, 4600f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower &dnn_data_input); 46190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? OP_REQUIRES_OK(context, context->status()); 46290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 46390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // Declare output tensor 46490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? Tensor* output_tensor = nullptr; 46590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? memory::dims output_dims_mkl_order; 46690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? this->GetOutputDims(pool_params, &output_dims_mkl_order); 46790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 468d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case // If input is an empty tensor, allocate an empty output tensor and return 469d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case if (input_tensor.NumElements() == 0) { 470d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case MklDnnShape output_mkl_shape; 471d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case output_mkl_shape.SetMklTensor(false); 472d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case TensorShape output_tf_shape; 473d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case if (pool_params.data_format == TensorFormat::FORMAT_NCHW) { 474d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order); 475d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case } else { 476d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case memory::dims output_dims_NHWC_order; 477d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case output_dims_NHWC_order = {pool_params.tensor_in_batch, 478d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case static_cast<int>(pool_params.out_height), 479d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case static_cast<int>(pool_params.out_width), 480d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case pool_params.out_depth}; 481d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order); 482d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case } 483d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case const int kOutputIndex = 0; 484d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor, 485d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case output_tf_shape, output_mkl_shape); 486d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case CHECK_NOTNULL(output_tensor); 487d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case return; 488d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case } 489d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case 49090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // If input is in Mkl layout, then just get the memory format from it 49190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // directly, instead of using input data_format to AvgPool. 49290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? if (dnn_shape_input.IsMklTensor()) { 4930f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower dnn_data_output.SetUsrMem( 4940f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower output_dims_mkl_order, 4950f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower static_cast<memory::format>( 4960f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower dnn_data_input.GetUsrMemDesc().data.format)); 49790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 49890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } else { 4990f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower dnn_data_output.SetUsrMem(output_dims_mkl_order, 5000f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower this->data_format_mkldnn_); 50190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } 50290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 5030f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower // describe the memory layout 50490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any); 50590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 50690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // 3. create a pooling primitive descriptor 5070f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower auto pool_desc = pooling_forward::desc( 5080f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower prop_kind::forward, algorithm::pooling_avg_exclude_padding, 5090f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(), 5100f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({pool_params.row_stride, pool_params.col_stride}), 5110f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({pool_params.window_rows, pool_params.window_cols}), 5120f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({static_cast<int>(pool_params.pad_top), 5130f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower static_cast<int>(pool_params.pad_left)}), 5140f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({static_cast<int>(pool_params.pad_bottom), 5150f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower static_cast<int>(pool_params.pad_right)}), 5160f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower TFPaddingToMklDnnPadding(this->padding_)); 5170f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower auto pool_prim_desc = 5180f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower pooling_forward::primitive_desc(pool_desc, cpu_engine); 51990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 52090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order, 5210f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower this->data_format_mkldnn_, &output_tensor); 52290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? CHECK_NOTNULL(output_tensor); 52390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 52490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? OP_REQUIRES_OK(context, context->status()); 52590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? dnn_data_output.SetUsrMemDataHandle(output_tensor); 52690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 5270f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input, 5280f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower &dnn_data_output); 5290f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower } catch (mkldnn::error& e) { 5300f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower string error_msg = "Status: " + std::to_string(e.status) + 5310f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower ", message: " + string(e.message) + ", in file " + 5320f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower string(__FILE__) + ":" + std::to_string(__LINE__); 5330f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower OP_REQUIRES_OK( 5340f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower context, 5350f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower errors::Aborted("Operation received an exception:", error_msg)); 53690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } 53790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } // Compute 5380f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower}; // MklAvgPoolingOp 53990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 54090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?//----------------------------------------------------------------------------- 54190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 54290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?template <class Device, class T> 54390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> { 54490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? public: 54590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? explicit MklAvgPoolingGradOp(OpKernelConstruction* context) 5460f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower : MklPoolingBackwardOpBase<T>(context) {} 54790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 54890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? void Compute(OpKernelContext* context) override { 54990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? try { 55090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? auto cpu_engine = engine(engine::cpu, 0); 55190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape; 5520f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower const Tensor& tensor_in_shape = 5530f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower MklGetInput(context, kInputTensorIndexInputShape); 5540f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower const Tensor& input_gradient_tensor = 5550f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower MklGetInput(context, kInputTensorIndexInputGradient); 55690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? GetMklShape(context, kInputTensorIndexInputShape, 5570f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower &original_input_mkl_shape); 55890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? GetMklShape(context, kInputTensorIndexInputGradient, 5590f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower &input_gradient_mkl_shape); 56090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 5610f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor, 5620f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower original_input_mkl_shape, input_gradient_mkl_shape); 56390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? if (!context->status().ok()) return; 56490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 56590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // Used to allocate output_diff_src/diff_src 56690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // and create pool_fwd mdm desc 56790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // 0. Input("orig_input_shape: int32") //NOT a T Tensor! 56890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // 1. Input("grad: T") 56990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 57090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? MklDnnData<T> input_gradient_diff_dst(&cpu_engine); 57190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? MklDnnData<T> output_diff_src(&cpu_engine); 57290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? Tensor* output_tensor_diff_src = nullptr; 57390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? TensorShape original_input_shape; 57490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? MklPoolParameters pool_params; 57590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? memory::dims output_dims_mkl_order, original_input_dims_nchw; 57690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // Configure the original input memory descriptor 5770f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::desc original_input_md = ConfigureOriginalInput( 5780f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower context, tensor_in_shape, original_input_mkl_shape, 5790f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower &original_input_dims_nchw, &pool_params, &original_input_shape); 58090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 58190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // configure the original output memory descriptor 58290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // by definition, the shape of the original output is the same 58390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // as the shape of the gradient diff_dst 58490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? memory::desc original_output_md = this->ConfigureOriginalOutput( 5850f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower pool_params, input_gradient_mkl_shape, output_dims_mkl_order); 58690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 58790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? memory::desc target_diff_dst_md = this->ConfigureInputGradient( 5880f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower input_gradient_mkl_shape, input_gradient_tensor, 5890f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower &input_gradient_diff_dst, original_output_md); 59090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // The shape of the output diff src needs to be the same shape as the 59190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // original input. But we will set its format to be same as the format of 59290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // input gradient. We won't use format of original input since it will 59390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // always be in Tensorflow layout (given that AvgPoolGrad gets shape of 59490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // the input rather than actual input). 5950f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower output_diff_src.SetUsrMem( 5960f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower original_input_dims_nchw, 5970f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower static_cast<memory::format>(target_diff_dst_md.data.format)); 59890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 59990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // Create the forward pooling primitive descriptor so we can reference it 60090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // in the backward pooling primitive descriptor 6010f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower auto pool_fwd_desc = pooling_forward::desc( 6020f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower prop_kind::forward, algorithm::pooling_avg_exclude_padding, 6030f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower original_input_md, original_output_md, 6040f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({pool_params.row_stride, pool_params.col_stride}), 6050f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({pool_params.window_rows, pool_params.window_cols}), 6060f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({static_cast<int>(pool_params.pad_top), 6070f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower static_cast<int>(pool_params.pad_left)}), 6080f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({static_cast<int>(pool_params.pad_bottom), 6090f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower static_cast<int>(pool_params.pad_right)}), 6100f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower TFPaddingToMklDnnPadding(this->padding_)); 6110f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower auto pool_fwd_prim_desc = 6120f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine); 61390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 61490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? auto pool_bkwd_desc = pooling_backward::desc( 6150f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower algorithm::pooling_avg_exclude_padding, 6160f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower output_diff_src.GetUsrMemDesc(), target_diff_dst_md, 6170f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({pool_params.row_stride, pool_params.col_stride}), 6180f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({pool_params.window_rows, pool_params.window_cols}), 6190f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({static_cast<int>(pool_params.pad_top), 6200f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower static_cast<int>(pool_params.pad_left)}), 6210f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims({static_cast<int>(pool_params.pad_bottom), 6220f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower static_cast<int>(pool_params.pad_right)}), 6230f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower TFPaddingToMklDnnPadding(this->padding_)); 6240f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower auto pool_bkwd_prim_desc = pooling_backward::primitive_desc( 6250f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc); 6260f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower this->AllocateOutputTensor( 6270f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower context, pool_bkwd_prim_desc, original_input_dims_nchw, 6280f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower this->data_format_mkldnn_, &output_tensor_diff_src); 62990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 63090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src); 63190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 6320f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower this->PrepareAndExecuteNet( 6330f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src, 6340f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::primitive_desc(target_diff_dst_md, cpu_engine)); 6350f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower } catch (mkldnn::error& e) { 63690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? string error_msg = "Status: " + std::to_string(e.status) + 6370f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower ", message: " + string(e.message) + ", in file " + 6380f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower string(__FILE__) + ":" + std::to_string(__LINE__); 6390f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", 6400f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower error_msg)); 64190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } 64290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } // Compute 64390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 64490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? private: 64590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // 0. Input("orig_input_shape: int32") 64690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // 1. Input("grad: T") 64790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? const int kInputTensorIndexInputShape = 0; 64890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? const int kInputTensorIndexInputGradient = 1; 64990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 6500f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::desc ConfigureOriginalInput( 6510f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower OpKernelContext* context, const Tensor& tensor_original_input_shape, 6520f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower const MklDnnShape& original_input_mkl_shape, 6530f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower memory::dims* original_input_dims_mkl_order, 6540f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower MklPoolParameters* pool_params, TensorShape* input_tensor_shape) { 65590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? CHECK_NOTNULL(original_input_dims_mkl_order); 65690e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? CHECK_NOTNULL(pool_params); 65790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? CHECK_NOTNULL(input_tensor_shape); 65890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // For AvgPoolGrad, we only get the size of the original input because 65990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // The original data is irrelvant. 66090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? auto shape_vec = tensor_original_input_shape.vec<int32>(); 66190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? for (int64 i = 0; i < tensor_original_input_shape.NumElements(); ++i) { 66290e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? input_tensor_shape->AddDim(shape_vec(i)); 66390e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } 66490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 66590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput( 6660f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower context, tensor_original_input_shape, original_input_mkl_shape, 6670f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower original_input_dims_mkl_order, pool_params, *input_tensor_shape); 6680f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower } 66990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 67090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? void SanityCheckInputs(OpKernelContext* context, 6710f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower const Tensor& tensor_in_shape, 6720f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower const Tensor& input_gradient_tensor, 6730f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower const MklDnnShape& original_input_mkl_shape, 6740f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower const MklDnnShape& input_gradient_mkl_shape) { 67590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? if (!original_input_mkl_shape.IsMklTensor()) { 6760f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower OP_REQUIRES( 6770f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower context, 6780f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4, 67990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? errors::InvalidArgument("original input shape must be " 6800f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower "1-dimensional and 4 elements")); 68190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } else { 6820f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower OP_REQUIRES(context, 6830f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower original_input_mkl_shape.GetDimension() == 1 && 6840f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower original_input_mkl_shape.DimSize(0) == 4, 6850f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower errors::InvalidArgument("original input shape must be " 6860f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower "1-dimensional and 4 elements")); 68790e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } 68890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 68990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? if (!input_gradient_mkl_shape.IsMklTensor()) { 69090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? // For avgpooling, input_gradient_diff_dst should have 4 dimensions. 69190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? OP_REQUIRES(context, input_gradient_tensor.dims() == 4, 6920f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower errors::InvalidArgument("Gradient shape must be " 6930f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower "4-dimensional")); 69490e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } else { 69590e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? OP_REQUIRES(context, input_gradient_mkl_shape.GetDimension() == 4, 6960f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower errors::InvalidArgument("Gradient shape must be " 6970f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower "4-dimensional")); 69890e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } 69990e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? } 70090e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man?}; // MklAvgPoolingGradOp 70190e42f3ac8c43474633136af4242dca04b6a1e09Dandelion Man? 702d90054e7c0f41f4bab81df0548577a73b939a87aMichael Case#endif // INTEL_MKL_ML 703ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 704326942394e69074d50d5889218a24c9371eff259Shanqing CaiREGISTER_KERNEL_BUILDER(Name("_MklAvgPool") 705ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower .Device(DEVICE_CPU) 706ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower .TypeConstraint<float>("T") 707326942394e69074d50d5889218a24c9371eff259Shanqing Cai .Label(mkl_op_registry::kMklOpLabel), 708ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklAvgPoolingOp<CPUDevice, float>); 709ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 710326942394e69074d50d5889218a24c9371eff259Shanqing CaiREGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad") 711ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower .Device(DEVICE_CPU) 712ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower .TypeConstraint<float>("T") 713326942394e69074d50d5889218a24c9371eff259Shanqing Cai .Label(mkl_op_registry::kMklOpLabel), 714ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower MklAvgPoolingGradOp<CPUDevice, float>); 715ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower 716ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower} // namespace tensorflow 717ccbc8991db3943ef984405881a1c917c530f902fA. Unique TensorFlower#endif // INTEL_MKL 718