1fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 3fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneLicensed under the Apache License, Version 2.0 (the "License"); 4fe97705b706c9dcd36586b6158e30758346c6afdVivek Raneyou may not use this file except in compliance with the License. 5fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneYou may obtain a copy of the License at 6fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 7fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane http://www.apache.org/licenses/LICENSE-2.0 8fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 9fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneUnless required by applicable law or agreed to in writing, software 10fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranedistributed under the License is distributed on an "AS IS" BASIS, 11fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12fe97705b706c9dcd36586b6158e30758346c6afdVivek RaneSee the License for the specific language governing permissions and 13fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranelimitations under the License. 14fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane==============================================================================*/ 15fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 16fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// See docs in ../ops/nn_ops.cc. 17fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#ifdef INTEL_MKL 18fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#define EIGEN_USE_THREADS 19fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#include "tensorflow/core/framework/op_kernel.h" 20fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#include "tensorflow/core/kernels/mkl_pooling_ops_common.h" 21fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#include "tensorflow/core/lib/core/errors.h" 22c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba#include "tensorflow/core/util/mkl_util.h" 23fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#include "tensorflow/core/util/padding.h" 24fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 25e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh#ifndef INTEL_MKL_ML 2604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina#include <algorithm> 2704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina#include "mkldnn.hpp" 28982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::algorithm; 29982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::engine; 3004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainausing mkldnn::error; 31982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::memory; 3204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainausing mkldnn::padding_kind; 33982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::pooling_backward; 34982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing mkldnn::pooling_forward; 3504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainausing mkldnn::prop_kind; 3604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina#endif 3704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 38fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranenamespace tensorflow { 39fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 40fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranetypedef Eigen::ThreadPoolDevice CPUDevice; 41fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 42e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh// MKL-DNN is now default. MKL-ML must be specified explicitly. 43e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh#ifdef INTEL_MKL_ML 4404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 45fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// An implementation of MaxPooling (forward). 46fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranetemplate <typename Device, typename T> 47fe97705b706c9dcd36586b6158e30758346c6afdVivek Raneclass MklMaxPoolingOp : public OpKernel { 48fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane public: 49fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane explicit MklMaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) { 50fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane string data_format; 51fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 52fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 53fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 54fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane errors::InvalidArgument("Invalid data format")); 55fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); 56fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES(context, ksize_.size() == 4, 57fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane errors::InvalidArgument("Sliding window ksize field must " 58fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane "specify 4 dimensions")); 59fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); 60fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES(context, stride_.size() == 4, 61fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane errors::InvalidArgument("Sliding window stride field must " 62fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane "specify 4 dimensions")); 63fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 64fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, 65fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane errors::Unimplemented("Pooling is not yet supported on the " 66fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane "batch dimension.")); 67fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 68fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane workspace_enabled_ = false; 69fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // We may not get this attribute for this node if it does not go through 70fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // graph rewrite pass. So we do not check for error while retrieving this 71fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // attribute value. 72fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane context->GetAttr("workspace_enabled", &workspace_enabled_); 73fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 74fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 75fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane void Compute(OpKernelContext* context) override { 76c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklMaxPoolingOpContext mkl_context; 77fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // Get the input tensor 78fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane const Tensor& tensor_in = MklGetInput(context, 0); 79c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba GetMklShape(context, 0, &mkl_context.input_shape); 80c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); 81c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba 82c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.params.in_dim = 4; 83c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklPoolParameters pool_params; 84c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (input_in_mkl_format == false) { 85c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pool_params.Init(context, ksize_, stride_, padding_, data_format_, 86c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba tensor_in.shape()); 87c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba OP_REQUIRES( 88c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba context, (pool_params.depth_window == 1), 89c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba errors::Unimplemented("Depthwise max pooling not supported by MKL")); 90fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 91fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } else { 92c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pool_params.Init(context, ksize_, stride_, padding_, data_format_, 93c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba &mkl_context.input_shape); 94fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 95fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 96fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // Extract the parameters for the op from the pooling specs 97fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 98c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); 99c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba 100c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.MklCreateLayoutsAndPrimitives(context); 1014eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane OP_REQUIRES_OK(context, context->status()); 102fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 103fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // Declare output tensor 104fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane TensorShape tensor_out_shape; 1054eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane MklShape mkl_out_shape, mkl_workspace_shape; 106fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane mkl_out_shape.SetMklTensor(true); 107c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst); 108c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_out_shape.SetTfLayout(mkl_context.params.in_dim, 109c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.params.out_sizes, 110c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.params.out_strides); 111c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); 112fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 113fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane Tensor* output_tensor = nullptr; 114c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>( 115c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_out_shape.GetMklLayout())) / 116c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba sizeof(T)); 117e0813b2c58890decedcddf84af839a1bf3c0bc76agramesh AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape, 118fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane mkl_out_shape); 119fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 120c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba Tensor* workspace_tensor; 121c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void* workspace_buf = nullptr; 122fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 1234eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane TensorShape workspace_shape; 1244eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane mkl_workspace_shape.SetMklTensor(false); 1254eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>( 1264eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane mkl_context.lt_workspace)) / 1274eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane sizeof(T)); 1284eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape, 1294eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane mkl_workspace_shape); 1304eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane 1314eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>( 1324eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane static_cast<const void*>(workspace_tensor->flat<T>().data())); 133c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.pooling_res[dnnResourceSrc] = 134c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data())); 135c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>( 136fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane static_cast<const void*>(output_tensor->flat<T>().data())); 137fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 138c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ( 139c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res), 140c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 141fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 142c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.MklCleanup(); 143fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 144fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 145fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane private: 146c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba typedef struct { 147c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklPoolingOpParams params; 148c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklShape input_shape; 149c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void* pooling_res[dnnResourceNumber]; 1504eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane dnnPrimitive_t prim_pooling_fwd = nullptr; 1514eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr; 152c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba 153c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void MklCreateLayoutsAndPrimitives(OpKernelContext* context) { 154c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool input_in_mkl_format = input_shape.IsMklTensor(); 155c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // Create or use existing DNN user layout 156c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (input_in_mkl_format == false) { 157c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutCreate_F32(<_user_input, params.in_dim, 158c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba params.in_sizes, params.in_strides), 159c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 160c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } else { 161c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba lt_user_input = (dnnLayout_t)input_shape.GetCurLayout(); 162c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 163fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 164c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax; 165c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnPrimitiveAttributes_t primAttr = nullptr; 166fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 167c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // Create DNN primitives 168c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnPoolingCreateForward_F32( 169c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba &prim_pooling_fwd, primAttr, algorithm, lt_user_input, 170c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba params.kernel_size, params.kernel_stride, params.in_offset, 171c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnBorderZerosAsymm), 172c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 173fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 174c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // Creates layout for the workspace 175c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, prim_pooling_fwd, 176c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnResourceWorkspace), 177fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane E_SUCCESS); 178fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 179fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 180c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void MklCleanup() { 181c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool input_in_mkl_format = input_shape.IsMklTensor(); 182c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); 183c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (!input_in_mkl_format) { 184c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS); 185c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 186c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS); 187fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 188c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } MklMaxPoolingOpContext; 189fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 190c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba std::vector<int32> ksize_; 191c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba std::vector<int32> stride_; 192c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba Padding padding_; 193c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba TensorFormat data_format_; 194c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool workspace_enabled_; 195fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane}; 196fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 197fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// The operation to compute MaxPool gradients. 198fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// It takes three inputs: 199fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// - The original input tensor 200fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// - The original output tensor 201fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// - Backprop tensor for output 202fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane// It produces one output: backprop tensor for input. 203fe97705b706c9dcd36586b6158e30758346c6afdVivek Ranetemplate <class Device, class T> 204fe97705b706c9dcd36586b6158e30758346c6afdVivek Raneclass MklMaxPoolingGradOp : public OpKernel { 205fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane public: 206fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane explicit MklMaxPoolingGradOp(OpKernelConstruction* context) 207fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane : OpKernel(context) { 208fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane string data_format; 209fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 210fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 211fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 212fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane errors::InvalidArgument("Invalid data format")); 213fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); 214fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES(context, ksize_.size() == 4, 215fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane errors::InvalidArgument("Sliding window ksize field must " 216fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane "specify 4 dimensions")); 217fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); 218fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES(context, stride_.size() == 4, 219fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane errors::InvalidArgument("Sliding window strides field must " 220fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane "specify 4 dimensions")); 221fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 222fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, 223fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane errors::Unimplemented( 224fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane "Pooling is not yet supported on the batch dimension.")); 225fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane workspace_enabled_ = false; 226fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // We may not get this attribute for this node if it does not go through 227fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // graph rewrite pass. So we do not check for error while retrieving this 228fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // attribute value. 229fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane context->GetAttr("workspace_enabled", &workspace_enabled_); 230fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 231fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 232fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane void Compute(OpKernelContext* context) override { 233c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklMaxPoolingGradOpContext mkl_context; 234fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // Input - The original input tensor 235fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane const Tensor& tensor_in = MklGetInput(context, 0); 236fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 237fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // Output - Backprop tensor for input. 238fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane Tensor* output_tensor = nullptr; 239fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 240c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba GetMklShape(context, 0, &mkl_context.input_shape); 241c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba GetMklShape(context, 2, &mkl_context.output_backprop_shape); 242c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); 243fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 244c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (input_in_mkl_format == false) 245c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.params.in_dim = tensor_in.dims(); 246fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane else 247c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.params.in_dim = mkl_context.input_shape.GetDimension(); 248fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 249c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklPoolParameters pool_params; 250c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (input_in_mkl_format == false) { 251c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pool_params.Init(context, ksize_, stride_, padding_, data_format_, 252c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba tensor_in.shape()); 253c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba OP_REQUIRES( 254c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba context, (pool_params.depth_window == 1), 255c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba errors::Unimplemented("Depthwise max pooling not supported by MKL")); 256fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 257fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } else { 258c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pool_params.Init(context, ksize_, stride_, padding_, data_format_, 259c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba &mkl_context.input_shape); 260fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 261fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 262fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // Extract the parameters for the op from the pooling specs 263c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); 264fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 265c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.MklCreateLayouts(context); 2664eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane OP_REQUIRES_OK(context, context->status()); 2674eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane 268c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.MklCreatePrimitives(context, workspace_enabled_); 2694eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane OP_REQUIRES_OK(context, context->status()); 2704eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane 271c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.MklPrepareInputs(context, workspace_enabled_); 2724eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane OP_REQUIRES_OK(context, context->status()); 273fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 274fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane // Create shape for the input back prop output 275fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane TensorShape mkl_input_backprop; 276c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklShape mkl_output_shape; 277c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_output_shape.SetMklTensor(true); 278c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_output_shape.SetMklLayout(mkl_context.prim_pooling_bwd, 279c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnResourceDiffSrc); 280c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_output_shape.SetTfLayout(mkl_context.params.in_dim, 281c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.params.in_sizes, 282c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.params.in_strides); 283c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_output_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); 284fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 285fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane mkl_input_backprop.AddDim( 286fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane dnnLayoutGetMemorySize_F32( 287c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) / 288fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane sizeof(T)); 289e0813b2c58890decedcddf84af839a1bf3c0bc76agramesh AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop, 290c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_output_shape); 291c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>( 292c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba static_cast<const void*>(output_tensor->flat<T>().data())); 293fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 294c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ( 295c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res), 296c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 297fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 298c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba mkl_context.MklCleanup(workspace_enabled_); 299fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 300fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 301fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane private: 302c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba typedef struct { 303c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklPoolingOpParams params; 304c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklShape input_shape, output_backprop_shape; 305c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void* pooling_resfwd[dnnResourceNumber]; 306c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void* pooling_res[dnnResourceNumber]; 3074eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr, 3084eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane convert_input = nullptr, convert_outbackprop = nullptr; 3094eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr, 3104eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane lt_input_user = nullptr, lt_input_prim = nullptr; 311c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void* input_buf; 312c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void* outbackprop_buf; 3134eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane Tensor tmp_output_buf_tensor; 3144eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane Tensor workspace_buf_tensor; 3154eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane Tensor input_buf_tensor, outbackprop_buf_tensor; 316c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba 317c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void MklCreateLayouts(OpKernelContext* context) { 318c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool input_in_mkl_format = input_shape.IsMklTensor(); 319c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); 320c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // Create DNN user layout for input and outbackprop or get existing layout 321c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (input_in_mkl_format == false) { 322c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutCreate_F32(<_input_user, params.in_dim, 323c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba params.in_sizes, params.in_strides), 324c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 325c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } else { 326c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba lt_input_user = (dnnLayout_t)input_shape.GetCurLayout(); 327c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 328fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 329e16cd2ede6a739a1c9c4576ebf8a4fe81e83f39eTaehoon Lee // We don't care about the output layout for now as we can create it from 330c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // primitives for the max pooling fwd prop 331c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (outbackprop_in_mkl_format == false) { 332c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutCreate_F32(<_outbackprop_user, params.in_dim, 333c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba params.out_sizes, params.out_strides), 334c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 335c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } else { 336c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba lt_outbackprop_user = (dnnLayout_t)output_backprop_shape.GetCurLayout(); 337c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 338fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 339fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 340c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // Create DNN primitives 341c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void MklCreatePrimitives(OpKernelContext* context, bool workspace_enabled) { 342c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax; 343c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnPrimitiveAttributes_t primAttr = nullptr; 344c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba 345c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (workspace_enabled == false) { 346c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnPoolingCreateForward_F32( 347c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba &prim_pooling_fwd, primAttr, algorithm, lt_input_user, 348c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba params.kernel_size, params.kernel_stride, params.in_offset, 349c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnBorderZerosAsymm), 350c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 351c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 352fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 353c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnPoolingCreateBackward_F32( 354c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba &prim_pooling_bwd, primAttr, algorithm, lt_input_user, 355c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba params.kernel_size, params.kernel_stride, params.in_offset, 356c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnBorderZerosAsymm), 357fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane E_SUCCESS); 358fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 359c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // Creates conversions 360c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 361c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba <_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst), 362fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane E_SUCCESS); 363fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 364c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (workspace_enabled == false) { 365c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 366c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba <_input_prim, prim_pooling_fwd, dnnResourceSrc), 367fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane E_SUCCESS); 368c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (!dnnLayoutCompare_F32(lt_input_user, lt_input_prim)) { 369c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input_user, 370c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba lt_input_prim), 371c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 372c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba AllocTmpBuffer(context, &input_buf_tensor, lt_input_prim, &input_buf); 373c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 374fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 375fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 376c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (!dnnLayoutCompare_F32(lt_outbackprop_user, lt_outbackprop_prim)) { 377c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ( 378c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnConversionCreate_F32(&convert_outbackprop, lt_outbackprop_user, 379c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba lt_outbackprop_prim), 380c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 381c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba AllocTmpBuffer(context, &outbackprop_buf_tensor, lt_outbackprop_prim, 382c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba &outbackprop_buf); 383c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 384fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 385fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 386c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // Compare incoming tensor layouts with MKL preferred layouts and convert 387c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // data to the preferred layout if necessary 388c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void MklPrepareInputs(OpKernelContext* context, bool workspace_enabled) { 389c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba const Tensor& tensor_in = MklGetInput(context, 0); 390c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba const Tensor& out_backprop = MklGetInput(context, 2); 391c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool input_in_mkl_format = input_shape.IsMklTensor(); 392c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); 393c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba 3944eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane void* tmp_output_buf = nullptr; 3954eaf53c278d46d23655fe9c420d08103c8b82112Vivek Rane void* workspace_buf = nullptr; 396c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba 397c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (workspace_enabled == false) { 398c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (convert_input != nullptr) { 399c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (input_in_mkl_format == false) { 400982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen CHECK_EQ(dnnConversionExecute_F32( 401982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen convert_input, 402982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const_cast<void*>(static_cast<const void*>( 403982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen tensor_in.flat<T>().data())), 404982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen input_buf), 405982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen E_SUCCESS); 406c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS); 407c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba convert_input = nullptr; 408c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } else { 409c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba input_shape.GetConvertedFlatData( 410982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen lt_input_prim, 411982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const_cast<void*>( 412982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<const void*>(tensor_in.flat<T>().data())), 413c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba input_buf); 414c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 415c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pooling_resfwd[dnnResourceSrc] = input_buf; 416fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } else { 417c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pooling_resfwd[dnnResourceSrc] = const_cast<void*>( 418c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba static_cast<const void*>(tensor_in.flat<T>().data())); 419fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 420fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 421c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnLayout_t lt_workspace; 422c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 423c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba <_workspace, prim_pooling_fwd, dnnResourceWorkspace), 424c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 425c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba AllocTmpBuffer(context, &workspace_buf_tensor, lt_workspace, 426c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba &workspace_buf); 427c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pooling_resfwd[dnnResourceWorkspace] = workspace_buf; 428fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 429c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba dnnLayoutDelete_F32(lt_workspace); 430fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 431c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // We create the layout for max pooling fwd prop tmp output here 432c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba AllocTmpBuffer(context, &tmp_output_buf_tensor, lt_outbackprop_prim, 433c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba &tmp_output_buf); 434c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pooling_resfwd[dnnResourceDst] = tmp_output_buf; 435fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 436c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnExecute_F32(prim_pooling_fwd, pooling_resfwd), E_SUCCESS); 437c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pooling_res[dnnResourceWorkspace] = 438c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pooling_resfwd[dnnResourceWorkspace]; 439fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } else { 440c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba const Tensor& workspace = MklGetInput(context, 3); 441c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pooling_res[dnnResourceWorkspace] = const_cast<void*>( 442c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba static_cast<const void*>(workspace.flat<T>().data())); 443fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 444fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 445c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba // Out backprop conversions if needed 446c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (convert_outbackprop != nullptr) { 447c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (outbackprop_in_mkl_format == false) { 448c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnConversionExecute_F32( 449c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba convert_outbackprop, 450c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba const_cast<void*>(static_cast<const void*>( 451c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba out_backprop.flat<T>().data())), 452c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba outbackprop_buf), 453c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba E_SUCCESS); 454c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS); 455c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } else { 456c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba output_backprop_shape.GetConvertedFlatData( 457982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen lt_outbackprop_prim, 458982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const_cast<void*>( 459982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<const void*>(out_backprop.flat<T>().data())), 460c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba outbackprop_buf); 461c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 462c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pooling_res[dnnResourceDiffDst] = outbackprop_buf; 463c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } else { 464c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba pooling_res[dnnResourceDiffDst] = const_cast<void*>( 465c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba static_cast<const void*>(out_backprop.flat<T>().data())); 466c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 467fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 468fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 469c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba void MklCleanup(bool workspace_enabled) { 470c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool input_in_mkl_format = input_shape.IsMklTensor(); 471c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); 472c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (workspace_enabled == false) { 473c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); 474c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 475c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS); 476c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (outbackprop_in_mkl_format == false) { 477c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user), E_SUCCESS); 478c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 479c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim), E_SUCCESS); 480c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (input_in_mkl_format == false) { 481c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutDelete_F32(lt_input_user), E_SUCCESS); 482c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 483c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba if (workspace_enabled == false) { 484c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim), E_SUCCESS); 485c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } 486fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane } 487c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba } MklMaxPoolingGradOpContext; 488fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 489c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba std::vector<int32> ksize_; 490c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba std::vector<int32> stride_; 491c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba Padding padding_; 492c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba TensorFormat data_format_; 493fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 494c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba bool workspace_enabled_; 49504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina}; // MklMaxPoolingGradOp 49604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 497e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh#else 49804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 49904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// An implementation of MaxPooling (forward). 50004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainatemplate <typename Device, typename T> 50104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainaclass MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { 50204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina public: 50304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina explicit MklMaxPoolingOp(OpKernelConstruction* context) 504982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen : MklPoolingForwardOpBase<T>(context) { 50504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina // In Max Pooling, MKLDNN does not allow passing workspace as NULL. 50604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina // So we set workspace_enabled_ to true. 50704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina this->workspace_enabled_ = true; 50804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina } 50904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 51004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina void Compute(OpKernelContext* context) override { 51104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina try { 51204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina auto cpu_engine = engine(engine::cpu, 0); 513982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const Tensor& input_tensor = 514982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklGetInput(context, this->kInputTensorIndexInput); 51504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina MklDnnShape dnn_shape_input; 51604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input); 51704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina this->SanityCheckInput(context, input_tensor, dnn_shape_input); 51804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina if (!context->status().ok()) return; 51904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 52004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina MklDnnData<T> dnn_data_input(&cpu_engine); 52104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina MklDnnData<T> dnn_data_output(&cpu_engine); 522ae700bb7462ee1bd4fed3c89441e962f64c89afdNiranjan Hasabnis MklDnnData<uint8> dnn_data_wksp(&cpu_engine); 52304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 52404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina // initialize variables for the pooling op 52504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina MklPoolParameters pool_params; 52604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina // Get the input tensor and initialize the pooling parameters 527982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params, 528982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen &dnn_data_input); 52904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina OP_REQUIRES_OK(context, context->status()); 53004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 53104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina // Declare output tensor 53204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina Tensor* output_tensor = nullptr; 53304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina memory::dims output_dims_mkl_order; 53404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina this->GetOutputDims(pool_params, &output_dims_mkl_order); 53504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 53604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina // If input is in Mkl layout, then just get the memory format from it 53704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina // directly, instead of using input data_format to MaxPool. 53804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina if (dnn_shape_input.IsMklTensor()) { 539982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen dnn_data_output.SetUsrMem( 540982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen output_dims_mkl_order, 541982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<memory::format>( 542982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen dnn_data_input.GetUsrMemDesc().data.format)); 54304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina } else { 54404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina dnn_data_output.SetUsrMem(output_dims_mkl_order, 54504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina this->data_format_mkldnn_); 54604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina } 54704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 54804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina // describe the memory layout; let mkl-dnn choose the best for the op 54904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any); 55004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 551982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen auto pool_desc = pooling_forward::desc( 552982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen prop_kind::forward, algorithm::pooling_max, 553982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(), 554982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({pool_params.row_stride, pool_params.col_stride}), 555982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({pool_params.window_rows, pool_params.window_cols}), 556982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({static_cast<int>(pool_params.pad_top), 557982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<int>(pool_params.pad_left)}), 558982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({static_cast<int>(pool_params.pad_bottom), 559982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<int>(pool_params.pad_right)}), 560982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen TFPaddingToMklDnnPadding(this->padding_)); 561982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen auto pool_fwd_desc = 562982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen pooling_forward::primitive_desc(pool_desc, cpu_engine); 56304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 56404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina this->AllocateOutputTensor(context, pool_fwd_desc, output_dims_mkl_order, 565982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen this->data_format_mkldnn_, &output_tensor); 56604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina OP_REQUIRES_OK(context, context->status()); 56704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina dnn_data_output.SetUsrMemDataHandle(output_tensor); 56804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 56904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina AllocateWorkspaceTensor(context, pool_fwd_desc, &dnn_data_wksp); 57004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina OP_REQUIRES_OK(context, context->status()); 57104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 57204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina this->PrepareAndExecuteNet(pool_fwd_desc, &dnn_data_input, 573982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen &dnn_data_output, &dnn_data_wksp); 574982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } catch (mkldnn::error& e) { 575982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen string error_msg = "Status: " + std::to_string(e.status) + 576982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen ", message: " + string(e.message) + ", in file " + 577982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen string(__FILE__) + ":" + std::to_string(__LINE__); 578982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", 579982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen error_msg)); 58004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina } 58104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina } // Compute 58204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 58304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina private: 584982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const int kOutputTensorIndexWorkspace = 1; 585982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 586982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen void AllocateWorkspaceTensor( 587982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OpKernelContext* context, 588982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const pooling_forward::primitive_desc& pool_fwd_prim_desc, 589982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklDnnData<uint8>* dnn_data_wksp) { 590982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen CHECK_NOTNULL(dnn_data_wksp); 591982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen Tensor* workspace_tensor = nullptr; 592982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::primitive_desc workspace_pd = 593982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen pool_fwd_prim_desc.workspace_primitive_desc(); 594982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen size_t workspace_bytes = workspace_pd.get_size(); 595982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklDnnShape workspace_mkl_shape; 596982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen workspace_mkl_shape.SetMklTensor(false); 597982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen TensorShape workspace_tf_shape; 598982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen workspace_tf_shape.AddDim(workspace_bytes); 599982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace, 600982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen &workspace_tensor, workspace_tf_shape, 601982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen workspace_mkl_shape); 602982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen CHECK_NOTNULL(workspace_tensor); 603982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor); 604982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } 605fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane}; 606fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 60704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// The operation to compute MaxPool gradients. 60804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// It takes three inputs: 60904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// - The original input tensor 61004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// - The original output tensor 61104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// - Backprop tensor for output 61204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina// It produces one output: backprop tensor for input. 61304807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainatemplate <class Device, class T> 61404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzainaclass MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { 61504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina public: 61604807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina explicit MklMaxPoolingGradOp(OpKernelConstruction* context) 617982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen : MklPoolingBackwardOpBase<T>(context) {} 61804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 61904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina void Compute(OpKernelContext* context) override { 62004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina try { 621982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen auto cpu_engine = engine(engine::cpu, 0); 622982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const Tensor& orig_input_tensor = 623982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklGetInput(context, kInputTensorIndexOrigInput); 624982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const Tensor& orig_output_tensor = 625982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklGetInput(context, kInputTensorIndexOrigOutput); 626982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const Tensor& grad_tensor = 627982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklGetInput(context, kInputTensorIndexGradient); 628982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const Tensor& workspace_tensor = 629982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklGetInput(context, kInputTensorIndexWorkspace); 630982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklDnnShape orig_input_mkl_shape, orig_output_mkl_shape, grad_mkl_shape, 631982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen workspace_mkl_shape; 632982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape); 633982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen GetMklShape(context, kInputTensorIndexOrigOutput, &orig_output_mkl_shape); 634982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape); 635982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen GetMklShape(context, kInputTensorIndexWorkspace, &workspace_mkl_shape); 636982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 637982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen SanityCheckInputs(context, orig_input_tensor, orig_output_tensor, 638982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen grad_tensor, workspace_tensor, orig_input_mkl_shape, 639982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen orig_output_mkl_shape, grad_mkl_shape, 640982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen workspace_mkl_shape); 641982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen if (!context->status().ok()) return; 642982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 643982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklDnnData<T> grad_dnn_data(&cpu_engine); 644982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklDnnData<uint8> workspace_dnn_data(&cpu_engine); 645982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklDnnData<T> output_dnn_data(&cpu_engine); 646982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen Tensor* output_tensor = nullptr; 647982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklPoolParameters pool_params; 648982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen TensorShape orig_input_shape; 649982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims output_dims_mkl_order, orig_input_dims_mkl_order; 650982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::desc original_input_md = ConfigureOriginalInput( 651982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen context, orig_input_tensor, orig_input_mkl_shape, 652982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen &orig_input_dims_mkl_order, &pool_params, &orig_input_shape); 653982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 654982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::desc original_output_md = this->ConfigureOriginalOutput( 655982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen pool_params, orig_output_mkl_shape, output_dims_mkl_order); 656982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 657982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::desc target_diff_dst_md = this->ConfigureInputGradient( 658982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen grad_mkl_shape, grad_tensor, &grad_dnn_data, original_output_md); 659982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 660982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen output_dnn_data.SetUsrMem(original_input_md); 661982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 662982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen // Create the forward pooling primitive descriptor so we can 663982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen // pass it as a hint to the backward pooling primitive descriptor 664982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen auto pool_fwd_desc = pooling_forward::desc( 665982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen prop_kind::forward, algorithm::pooling_max, original_input_md, 666982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen original_output_md, 667982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({pool_params.row_stride, pool_params.col_stride}), 668982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({pool_params.window_rows, pool_params.window_cols}), 669982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({static_cast<int>(pool_params.pad_top), 670982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<int>(pool_params.pad_left)}), 671982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({static_cast<int>(pool_params.pad_bottom), 672982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<int>(pool_params.pad_right)}), 673982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen TFPaddingToMklDnnPadding(this->padding_)); 674982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen auto pool_fwd_prim_desc = 675982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine); 676982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 677982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen auto pool_bkwd_desc = pooling_backward::desc( 678982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen algorithm::pooling_max, output_dnn_data.GetUsrMemDesc(), 679982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen target_diff_dst_md, 680982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({pool_params.row_stride, pool_params.col_stride}), 681982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({pool_params.window_rows, pool_params.window_cols}), 682982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({static_cast<int>(pool_params.pad_top), 683982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<int>(pool_params.pad_left)}), 684982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims({static_cast<int>(pool_params.pad_bottom), 685982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<int>(pool_params.pad_right)}), 686982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen TFPaddingToMklDnnPadding(this->padding_)); 687982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen auto pool_bkwd_prim_desc = pooling_backward::primitive_desc( 688982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc); 689982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 690982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen this->AllocateOutputTensor(context, pool_bkwd_prim_desc, 691982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen orig_input_dims_mkl_order, 692982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen this->data_format_mkldnn_, &output_tensor); 693982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen output_dnn_data.SetUsrMemDataHandle(output_tensor); 694982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 695982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen ConfigureWorkspace(workspace_tensor, 696982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen pool_fwd_prim_desc.workspace_primitive_desc(), 697982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen &workspace_dnn_data); 698982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen this->PrepareAndExecuteNet( 699982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen pool_bkwd_prim_desc, &grad_dnn_data, &output_dnn_data, 700982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::primitive_desc(target_diff_dst_md, cpu_engine), 701982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen &workspace_dnn_data); 702982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } catch (mkldnn::error& e) { 703982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen string error_msg = "Status: " + std::to_string(e.status) + 704982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen ", message: " + string(e.message) + ", in file " + 705982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen string(__FILE__) + ":" + std::to_string(__LINE__); 706982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", 707982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen error_msg)); 70804807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina } 70904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina } // Compute 71004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 71104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina private: 712982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen // .Input("orig_input: T") 713982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen // .Input("orig_output: T") 714982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen // .Input("grad: T") 715982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen // .Input("workspace: T") 716982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const int kInputTensorIndexOrigInput = 0; 717982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const int kInputTensorIndexOrigOutput = 1; 718982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const int kInputTensorIndexGradient = 2; 719982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const int kInputTensorIndexWorkspace = 3; 720982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen // Output("output: T") in Base Class 721982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen 722982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::desc ConfigureOriginalInput( 723982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OpKernelContext* context, const Tensor& tensor_original_input, 724982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const MklDnnShape& original_input_mkl_shape, 725982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::dims* original_input_dims_mkl_order, 726982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklPoolParameters* pool_params, TensorShape* input_tensor_shape) { 727982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen *input_tensor_shape = tensor_original_input.shape(); 728982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput( 729982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen context, tensor_original_input, original_input_mkl_shape, 730982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen original_input_dims_mkl_order, pool_params, *input_tensor_shape); 731982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } 73204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 733982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen void ConfigureWorkspace(const Tensor& workspace_tensor, 734982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen memory::primitive_desc workspace_pd, 735982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen MklDnnData<uint8>* workspace_dnn_data) { 736982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen CHECK_NOTNULL(workspace_dnn_data); 73704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 738982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor); 739982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } 74004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 741982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen void SanityCheckInputs(OpKernelContext* context, 742982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const Tensor& orig_input_tensor, 743982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const Tensor& orig_output_tensor, 744982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const Tensor& grad_tensor, 745982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const Tensor& workspace_tensor, 746982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const MklDnnShape& orig_input_mkl_shape, 747982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const MklDnnShape& orig_output_mkl_shape, 748982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const MklDnnShape& grad_mkl_shape, 749982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen const MklDnnShape& workspace_mkl_shape) { 750982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen if (!orig_input_mkl_shape.IsMklTensor()) { 751982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, orig_input_tensor.dims() == 4, 752982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Original input shape must be " 753982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "4-dimensional")); 754982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } else { 755982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, orig_input_mkl_shape.GetDimension() == 4, 756982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Original input shape must be " 757982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "4-dimensional")); 758982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } 759982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen if (!orig_output_mkl_shape.IsMklTensor()) { 760982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, orig_output_tensor.dims() == 4, 761982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Original output must be " 762982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "4-dimensional")); 763982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } else { 764982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, orig_output_mkl_shape.GetDimension() == 4, 765982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Original output must be " 766982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "4-dimensional")); 767982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } 768982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen if (!grad_mkl_shape.IsMklTensor()) { 769982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, grad_tensor.dims() == 4, 770982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Gradient must be 4-dimensional")); 771982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } else { 772982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, grad_mkl_shape.GetDimension() == 4, 773982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Gradient must be " 774982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "4-dimensional")); 775982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } 776982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen if (this->workspace_enabled_) { 777982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen // The workspace should not be an MKL tensor 778982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false, 779982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Workspace tensor should not" 780982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen " be an MKL Tensor.")); 781982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen // It should only have one dimension 782982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, workspace_tensor.dims() == 1, 783982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Workspace tensor must be " 784982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "1-dimensional")); 785982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } else { 786982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES( 787982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen context, this->workspace_enabled_, 788982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::Unimplemented("MKL-DNN Max Pooling does not " 78904807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina "yet support the use case " 79004807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina "where MaxPoolGrad is called without first" 79104807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina " calling MaxPool.")); 79204807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina } 793982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen } 79404807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina}; // MklMaxPoolingGradOp 79504807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 796e4a628adf84a2373d773103cdeabc96cbffd7b47AG Ramesh#endif // INTEL_MKL_ML 79704807b625b0260c4daff98a618f3742c5fe1a782Mahmoud Abuzaina 798326942394e69074d50d5889218a24c9371eff259Shanqing CaiREGISTER_KERNEL_BUILDER(Name("_MklMaxPool") 799c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba .Device(DEVICE_CPU) 800c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba .TypeConstraint<float>("T") 80167f9925ef9ceed02892c200a3122092ab497943aNiranjan Hasabnis .Label(mkl_op_registry::kMklOpLabel), 802c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklMaxPoolingOp<CPUDevice, float>); 803fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 804326942394e69074d50d5889218a24c9371eff259Shanqing CaiREGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad") 805c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba .Device(DEVICE_CPU) 806c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba .TypeConstraint<float>("T") 80767f9925ef9ceed02892c200a3122092ab497943aNiranjan Hasabnis .Label(mkl_op_registry::kMklOpLabel), 808c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba MklMaxPoolingGradOp<CPUDevice, float>); 809fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane 810c3461478b7bdfcec54d3815191fd06090182deceJayaram Bobba} // namespace tensorflow 811fe97705b706c9dcd36586b6158e30758346c6afdVivek Rane#endif // INTEL_MKL 812