1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// See docs in ../ops/nn_ops.cc. This op kernel uses the MKL library to create
// MKL layouts and primitives, and uses MKL DNN primitives to compute the
// convolution backward input (the gradient with respect to the input).
19
20#ifdef INTEL_MKL
21
22#define USE_EIGEN_TENSOR
23#define EIGEN_USE_THREADS
24#include <algorithm>
25#include <vector>
26#include "mkl_dnn.h"
27#include "mkl_dnn_types.h"
28#include "tensorflow/core/framework/numeric_op.h"
29#include "tensorflow/core/framework/op_kernel.h"
30#include "tensorflow/core/framework/register_types.h"
31#include "tensorflow/core/framework/tensor.h"
32#include "tensorflow/core/framework/tensor_shape.h"
33#include "tensorflow/core/framework/tensor_slice.h"
34#include "tensorflow/core/kernels/conv_grad_ops.h"
35#include "tensorflow/core/kernels/mkl_conv_ops.h"
36#include "tensorflow/core/kernels/ops_util.h"
37#include "tensorflow/core/lib/core/errors.h"
38#include "tensorflow/core/lib/gtl/array_slice.h"
39#include "tensorflow/core/platform/logging.h"
40#include "tensorflow/core/platform/macros.h"
41#include "tensorflow/core/util/mkl_util.h"
42#include "tensorflow/core/util/padding.h"
43#include "tensorflow/core/util/tensor_format.h"
44#include "tensorflow/core/util/use_cudnn.h"
45#include "tensorflow/core/util/work_sharder.h"
46
47#ifndef INTEL_MKL_ML
48#include "mkldnn.hpp"
49
50using mkldnn::convolution_backward_data;
51using mkldnn::prop_kind;
52using mkldnn::stream;
53#endif
54
55namespace tensorflow {
56
57typedef Eigen::ThreadPoolDevice CPUDevice;
58
59#ifdef INTEL_MKL_ML
60
61template <typename Device, class T>
62class MklConv2DCustomBackpropInputOp : public OpKernel {
63 public:
64  ~MklConv2DCustomBackpropInputOp() {}
65  explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
66      : OpKernel(context) {
67    string dataformat;
68    OP_REQUIRES_OK(context, context->GetAttr("data_format", &dataformat));
69    OP_REQUIRES(context, FormatFromString(dataformat, &data_format),
70                errors::InvalidArgument("Invalid data format"));
71    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides));
72    int stride_n = GetTensorDim(strides, data_format, 'N');
73    int stride_c = GetTensorDim(strides, data_format, 'C');
74    OP_REQUIRES(
75        context, (stride_n == 1 && stride_c == 1),
76        errors::InvalidArgument("Current implementation does not yet support "
77                                "strides in the batch and depth dimensions."));
78
79    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding));
80  }
81
82  void Compute(OpKernelContext* context) override {
83    MklConvBackInputOpContext mkl_context;
84    const Tensor& input = MklGetInput(context, 0);
85    const Tensor& filter = MklGetInput(context, 1);
86
87    GetMklShape(context, 1, &(mkl_context.filter_shape));
88    bool filter_in_mkl_format = mkl_context.filter_shape.IsMklTensor();
89
90    const Tensor& out_backprop = MklGetInput(context, 2);
91    GetMklShape(context, 2, &(mkl_context.outback_shape));
92    bool outback_in_mkl_format = mkl_context.outback_shape.IsMklTensor();
93
94    TensorShape input_shape, filter_shape, outback_shape;
95
96    // Generate input shape.
97    OP_REQUIRES(
98        context, TensorShapeUtils::IsVector(input.shape()),
99        errors::InvalidArgument(
100            "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
101            input.dims()));
102    OP_REQUIRES_OK(
103        context, TensorShapeUtils::MakeShape(input.vec<int32>(), &input_shape));
104
105    // Generate shape for filter prop if input is in MKL format.
106    if (filter_in_mkl_format) {
107      OP_REQUIRES(context, mkl_context.filter_shape.GetDimension() == 4,
108                  errors::InvalidArgument(
109                      "Conv2DCustomBackpropInput: size must be 4-dim"));
110
111      const int64* filter_sizes =
112          (const int64*)mkl_context.filter_shape.GetSizes();
113      const int64 filter_dims = mkl_context.filter_shape.GetDimension();
114
115      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
116                                  filter_sizes, filter_dims, &filter_shape));
117    } else {
118      filter_shape = filter.shape();
119    }
120
121    // Generate shape for outback prop if input is in MKL format.
122    if (outback_in_mkl_format) {
123      OP_REQUIRES(context, mkl_context.outback_shape.GetDimension() == 4,
124                  errors::InvalidArgument(
125                      "Conv2DCustomBackpropInput: size must be 4-dim"));
126
127      MklSizesToTFSizes(context, data_format, mkl_context.outback_shape,
128                        &outback_shape);
129    } else {
130      outback_shape = out_backprop.shape();
131    }
132
133    ConvBackpropDimensions dims;
134    OP_REQUIRES_OK(
135        context,
136        ConvBackpropComputeDimensions(
137            "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2, input_shape,
138            filter_shape, outback_shape, strides, padding, data_format, &dims));
139
140    int64 pad_top, pad_bottom;
141    int64 pad_left, pad_right;
142    OP_REQUIRES_OK(
143        context,
144        GetWindowedOutputSizeVerbose(
145            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
146            dims.spatial_dims[0].stride, padding,
147            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
148    OP_REQUIRES_OK(
149        context,
150        GetWindowedOutputSizeVerbose(
151            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
152            dims.spatial_dims[1].stride, padding,
153            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
154
155    mkl_context.in_dims = 4;
156
157    mkl_context.in_sizes[0] =
158        static_cast<size_t>(dims.spatial_dims[1].input_size);
159    mkl_context.in_sizes[1] =
160        static_cast<size_t>(dims.spatial_dims[0].input_size);
161    mkl_context.in_sizes[2] = static_cast<size_t>(dims.in_depth);
162    mkl_context.in_sizes[3] = static_cast<size_t>(dims.batch_size);
163
164    mkl_context.out_sizes[0] =
165        static_cast<size_t>(dims.spatial_dims[1].output_size);
166    mkl_context.out_sizes[1] =
167        static_cast<size_t>(dims.spatial_dims[0].output_size);
168    mkl_context.out_sizes[2] = static_cast<size_t>(dims.out_depth);
169    mkl_context.out_sizes[3] = static_cast<size_t>(dims.batch_size);
170
171    mkl_context.input_offset[0] = static_cast<int>(-pad_left);
172    mkl_context.input_offset[1] = static_cast<int>(-pad_top);
173
174    mkl_context.conv_strides[0] =
175        static_cast<size_t>(dims.spatial_dims[1].stride);
176    mkl_context.conv_strides[1] =
177        static_cast<size_t>(dims.spatial_dims[0].stride);
178
179    GetStridesFromSizes(data_format, mkl_context.out_strides,
180                        mkl_context.out_sizes);
181    GetStridesFromSizes(data_format, mkl_context.in_strides,
182                        mkl_context.in_sizes);
183
184    mkl_context.filter_size[0] = dims.spatial_dims[1].filter_size;
185    mkl_context.filter_size[1] = dims.spatial_dims[0].filter_size;
186    mkl_context.filter_size[2] = dims.in_depth;
187    mkl_context.filter_size[3] = dims.out_depth;
188
189    mkl_context.filter_stride[0] =
190        mkl_context.filter_size[2] * mkl_context.filter_size[3];
191    mkl_context.filter_stride[1] = mkl_context.filter_size[2] *
192                                   mkl_context.filter_size[0] *
193                                   mkl_context.filter_size[3];
194    mkl_context.filter_stride[2] = mkl_context.filter_size[3];
195    mkl_context.filter_stride[3] = 1;
196
197    CHECK_EQ(
198        dnnConvolutionCreateBackwardData_F32(
199            &mkl_context.prim_bwddata, NULL, dnnAlgorithmConvolutionDirect,
200            mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes,
201            mkl_context.filter_size, mkl_context.conv_strides,
202            mkl_context.input_offset, dnnBorderZeros),
203        E_SUCCESS);
204
205    // Allocate output tensor and shape
206    TensorShape mkl_out_shape;
207    MklShape mklOutputShape;
208    mklOutputShape.SetMklTensor(true);
209    mklOutputShape.SetMklLayout(mkl_context.prim_bwddata, dnnResourceDiffSrc);
210    mklOutputShape.SetTfLayout(mkl_context.in_dims, mkl_context.in_sizes,
211                               mkl_context.in_strides);
212    // MKL might change the dimension ordering.
213    // Create mapping to recover the original TF dimension order
214    mklOutputShape.SetTfDimOrder(mkl_context.in_dims, data_format);
215
216    Tensor* in_backprop = nullptr;
217    mkl_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
218                             mklOutputShape.GetMklLayout())) /
219                         sizeof(T));
220    AllocateOutputSetMklShape(context, 0, &in_backprop, mkl_out_shape,
221                              mklOutputShape);
222
223    mkl_context.conv_res[dnnResourceDiffSrc] =
224        static_cast<void*>(const_cast<T*>(in_backprop->flat<T>().data()));
225
226    mkl_context.MklCreateInputLayouts(context);
227    Tensor mkl_tmp_outbackprop_buf_tensor, mkl_tmp_filter_buf_tensor;
228    mkl_context.MklPrepareConvolutionInputs(
229        context, &mkl_tmp_outbackprop_buf_tensor, &mkl_tmp_filter_buf_tensor);
230
231    CHECK_EQ(dnnExecute_F32(mkl_context.prim_bwddata, mkl_context.conv_res),
232             E_SUCCESS);
233    mkl_context.MklCleanup();
234  }
235
236 private:
237  typedef struct {
238    int in_dims;
239    size_t in_sizes[4];
240    size_t in_strides[4];
241    size_t out_sizes[4];
242    size_t out_strides[4];
243    int input_offset[2];
244    size_t filter_size[4];
245    size_t filter_stride[4];
246    size_t conv_strides[2];
247    MklShape filter_shape, outback_shape;
248    dnnPrimitive_t prim_bwddata;
249    void* conv_res[dnnResourceNumber];
250    dnnLayout_t lt_filter, lt_outbackprop;
251
252    // Create MKL dnnLayout_t objects for tensors coming into the layer
253    void MklCreateInputLayouts(OpKernelContext* context) {
254      bool filter_in_mkl_format = filter_shape.IsMklTensor();
255      bool outback_in_mkl_format = outback_shape.IsMklTensor();
256      if (filter_in_mkl_format) {
257        lt_filter = (dnnLayout_t)filter_shape.GetCurLayout();
258      } else {
259        CHECK_EQ(dnnLayoutCreate_F32(&lt_filter, in_dims, filter_size,
260                                     filter_stride),
261                 E_SUCCESS);
262      }
263
264      if (outback_in_mkl_format) {
265        lt_outbackprop = (dnnLayout_t)outback_shape.GetCurLayout();
266      } else {
267        CHECK_EQ(dnnLayoutCreate_F32(&lt_outbackprop, in_dims, out_sizes,
268                                     out_strides),
269                 E_SUCCESS);
270      }
271    }
272
273    // Compare incoming input tensor layouts with MKL preferred layouts and
274    // convert data to the preferred layout if necessary
275    void MklPrepareConvolutionInputs(OpKernelContext* context,
276                                     Tensor* mkl_tmp_outbackprop_buf_tensor,
277                                     Tensor* mkl_tmp_filter_buf_tensor) {
278      dnnPrimitive_t mkl_convert_filter = nullptr,
279                     mkl_convert_outbackprop = nullptr;
280      void *mkl_filter_buf = nullptr, *mkl_outbackprop_buf = nullptr;
281      dnnLayout_t mkl_lt_filter_internal = nullptr,
282                  mkl_lt_outbackprop_internal = nullptr;
283      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
284                   &mkl_lt_filter_internal, prim_bwddata, dnnResourceFilter),
285               E_SUCCESS);
286
287      const Tensor& filter = MklGetInput(context, 1);
288
289      CHECK_EQ(
290          dnnLayoutCreateFromPrimitive_F32(&mkl_lt_outbackprop_internal,
291                                           prim_bwddata, dnnResourceDiffDst),
292          E_SUCCESS);
293      if (!dnnLayoutCompare_F32(mkl_lt_filter_internal, lt_filter)) {
294        // Create conversion primitive
295        CHECK_EQ(dnnConversionCreate_F32(&mkl_convert_filter, lt_filter,
296                                         mkl_lt_filter_internal),
297                 E_SUCCESS);
298
299        AllocTmpBuffer(context, mkl_tmp_filter_buf_tensor,
300                       mkl_lt_filter_internal, &mkl_filter_buf);
301        CHECK_EQ(
302            dnnConversionExecute_F32(
303                mkl_convert_filter,
304                static_cast<void*>(const_cast<T*>(filter.flat<T>().data())),
305                mkl_filter_buf),
306            E_SUCCESS);
307
308        // Assign filter buf to resources[] for convolution.
309        conv_res[dnnResourceFilter] = mkl_filter_buf;
310        dnnDelete_F32(mkl_convert_filter);
311      } else {
312        // If we do not need any layout conversion for filter, then
313        // we directly assign input filter to resources[].
314        conv_res[dnnResourceFilter] =
315            static_cast<void*>(const_cast<T*>(filter.flat<T>().data()));
316      }
317      dnnLayoutDelete_F32(mkl_lt_filter_internal);
318      const Tensor& out_backprop = MklGetInput(context, 2);
319      // --
320      // We do similar steps as above for outputbackprop.
321      if (!dnnLayoutCompare_F32(mkl_lt_outbackprop_internal, lt_outbackprop)) {
322        CHECK_EQ(
323            dnnConversionCreate_F32(&mkl_convert_outbackprop, lt_outbackprop,
324                                    mkl_lt_outbackprop_internal),
325            E_SUCCESS);
326        AllocTmpBuffer(context, mkl_tmp_outbackprop_buf_tensor,
327                       mkl_lt_outbackprop_internal, &mkl_outbackprop_buf);
328
329        CHECK_EQ(dnnConversionExecute_F32(mkl_convert_outbackprop,
330                                          static_cast<void*>(const_cast<T*>(
331                                              out_backprop.flat<T>().data())),
332                                          mkl_outbackprop_buf),
333                 E_SUCCESS);
334
335        conv_res[dnnResourceDiffDst] = mkl_outbackprop_buf;
336        dnnDelete_F32(mkl_convert_outbackprop);
337      } else {
338        conv_res[dnnResourceDiffDst] =
339            static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data()));
340      }
341      dnnLayoutDelete_F32(mkl_lt_outbackprop_internal);
342    }
343
344    // Cleanup member layouts and primitives
345    void MklCleanup() {
346      bool filter_in_mkl_format = filter_shape.IsMklTensor();
347      bool outback_in_mkl_format = outback_shape.IsMklTensor();
348      if (!filter_in_mkl_format) dnnLayoutDelete_F32(lt_filter);
349      if (!outback_in_mkl_format) dnnLayoutDelete_F32(lt_outbackprop);
350      dnnDelete_F32(prim_bwddata);
351    }
352  } MklConvBackInputOpContext;
353
354  std::vector<int32> strides;
355  Padding padding;
356  TensorFormat data_format;
357};
358
359#else
360
361template <typename Device, class T>
362class MklConv2DCustomBackpropInputOp
363    : public MklConv2DBackpropCommonOp<Device, T> {
364 public:
365  explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
366      : MklConv2DBackpropCommonOp<Device, T>(context) {}
367  ~MklConv2DCustomBackpropInputOp() {}
368
369 private:
370  const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0,
371            kInputIndex_OutBackProp = 2;
372  void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
373                         const MklDnnShape& filter_mkl_shape,
374                         const MklDnnShape& obp_mkl_shape) {
375    // Tensor that feeds to 'Input' slot of BackpropInput is always just a shape
376    // of the Tensor and never an actual tensor. So it will never be in MKL
377    // layout.
378    CHECK(!input_mkl_shape.IsMklTensor())
379        << "Conv2DBackpropInput: input should not be in MKL Layout";
380  }
381
382  size_t GetInputTensorIndexWithSizes() { return kInputIndex_InputSizes; }
383
384  TensorShape MakeInputTfShape(OpKernelContext* context,
385                               const Tensor& input_tensor) {
386    TensorShape input_tf_shape;
387    CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
388    CHECK_EQ(
389        TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), &input_tf_shape)
390            .ok(),
391        true);
392    return input_tf_shape;
393  }
394
395  TensorShape MakeFilterTfShape(OpKernelContext* context,
396                                const Tensor& filter_tensor) {
397    return GetTfShape(context, kInputIndex_Filter);
398  }
399
400  TensorShape GetOutputTfShape(const TensorShape& input_shape,
401                               const TensorShape& filter_shape,
402                               const TensorShape& outbprop_shape) {
403    // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'.
404    return input_shape;
405  }
406
407  const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
408                                    const memory::dims& fwd_filter_dims) {
409    // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'.
410    return fwd_input_dims;
411  }
412
413  memory::format GetOutputFormat(const memory::format data_format) {
414    // Output layout is Tensorflow's layout in data format order.
415    return data_format;
416  }
417
418  void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine,
419                       const convolution_forward::primitive_desc& conv_fwd_pd,
420                       MklDnnData<T>* input, MklDnnData<T>* filter,
421                       MklDnnData<T>* outbackprop, MklDnnData<T>* output,
422                       Tensor** output_tensor, const memory::dims& strides,
423                       const memory::dims& padding_l,
424                       const memory::dims& padding_r, padding_kind padding,
425                       const memory::dims& bwd_output_dims,
426                       memory::format bwd_output_format) {
427    CHECK_NOTNULL(context);
428    CHECK_NOTNULL(input);
429    CHECK_NOTNULL(filter);
430    CHECK_NOTNULL(outbackprop);
431    CHECK_NOTNULL(output);
432    CHECK_NOTNULL(output_tensor);
433
434    // Create convolution backward data primitive.
435    auto bwd_desc = convolution_backward_data::desc(
436        convolution_direct, output->GetOpMemDesc(), filter->GetOpMemDesc(),
437        outbackprop->GetOpMemDesc(), strides, padding_l, padding_r, padding);
438
439    auto bwd_pd = convolution_backward_data::primitive_desc(
440        bwd_desc, cpu_engine, conv_fwd_pd);
441
442    // Allocate output tensor in TensorFlow and MKL layout.
443    AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format,
444                         output_tensor);
445    CHECK_NOTNULL(*output_tensor);
446    // Set buffer handle using allocated output tensor.
447    output->SetUsrMemDataHandle(*output_tensor);
448
449    PrepareAndExecutePrimitive(bwd_pd, filter, outbackprop, output);
450  }
451
452  // Allocate output tensor.
453  void AllocateOutputTensor(
454      OpKernelContext* context,
455      const convolution_backward_data::primitive_desc& conv_pd,
456      const memory::dims& output_dims_mkl_order,
457      memory::format output_tf_format, Tensor** output_tensor) {
458    CHECK_NOTNULL(output_tensor);
459
460    // Output primitive descriptor for backward data is diff_src.
461    auto dst_pd = conv_pd.diff_src_primitive_desc();
462
463    // Allocate shape of Mkl tensor.
464    MklDnnShape output_mkl_shape;
465    output_mkl_shape.SetMklTensor(true);
466    output_mkl_shape.SetMklLayout(&dst_pd);
467    output_mkl_shape.SetElemType(MklDnnType<T>());
468    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
469                                 output_dims_mkl_order, output_tf_format);
470
471    // Allocate shape of TF tensor.
472    TensorShape output_tf_shape;
473    output_tf_shape.AddDim(dst_pd.get_size() / sizeof(T));
474
475    AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
476                              output_mkl_shape);
477  }
478
479  // Prepare and execute net - checks for input and output reorders.
480  void PrepareAndExecutePrimitive(
481      const convolution_backward_data::primitive_desc& conv_pd,
482      MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) {
483    // Create reorders between user layout and MKL layout if it is needed and
484    // add it to the net before convolution.
485    std::vector<primitive> net;
486    filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net);
487    obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
488
489    net.push_back(convolution_backward_data(
490        conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem()));
491
492    stream(stream::kind::eager).submit(net).wait();
493  }
494};
495
496#endif  // INTEL_MKL_ML
497
498#define REGISTER_MKL_CPU_KERNELS(T)                                 \
499  REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput")           \
500                              .Device(DEVICE_CPU)                   \
501                              .TypeConstraint<T>("T")               \
502                              .Label(mkl_op_registry::kMklOpLabel), \
503                          MklConv2DCustomBackpropInputOp<CPUDevice, T>);
504
505TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
506#undef REGISTER_MKL_CPU_KERNELS
507
508}  // namespace tensorflow
509#endif  // INTEL_MKL
510