/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include <numeric>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
using mkldnn::sum;
#endif

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

#ifdef INTEL_MKL_ML

template <typename Device, typename T>
class MklAddNOp : public OpKernel {
 public:
  explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const int num = ctx->num_inputs();
    OP_REQUIRES(ctx, num / 2 == 2,
                errors::InvalidArgument("Only addition of two tensors is "
                                        "supported by MKL. Num inputs: ",
                                        num));
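    // Note: in the MKL layout pass every data input is paired with a meta
    // "MKL shape" tensor, so num_inputs() counts both; num / 2 is the number
    // of logical inputs being added.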

    MklAddNOpContext mkl_context;
    size_t src1_idx = 0, src2_idx = 1;
    const Tensor& input0 = MklGetInput(ctx, src1_idx);
    GetMklShape(ctx, src1_idx, &(mkl_context.input1_shape));
    bool input1_in_mkl_format = mkl_context.input1_shape.IsMklTensor();

    const Tensor& input1 = MklGetInput(ctx, src2_idx);
    GetMklShape(ctx, src2_idx, &(mkl_context.input2_shape));
    bool input2_in_mkl_format = mkl_context.input2_shape.IsMklTensor();

    // If the shapes of the two tensors differ, raise an op error.
    TensorShape src1_shape, src2_shape;
    src1_shape = input0.shape();
    src2_shape = input1.shape();
    if (!src1_shape.IsSameSize(src2_shape)) {
      ctx->SetStatus(errors::InvalidArgument(
          "Inputs to operation ", this->name(), " of type ",
          this->type_string(), " must have the same size and shape.  Input 0: ",
          src1_shape.DebugString(), " != input 1: ", src2_shape.DebugString()));
      return;
    }
    // Handle the scalar case.
    if (!input1_in_mkl_format && input0.dims() == 0) {
      const TensorShape& o_shape = input0.shape();
      Tensor* out_tensor = nullptr;
      mkl_context.output_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(ctx, src1_idx, &out_tensor, o_shape,
                                mkl_context.output_shape);
      float user_i1 = (input0.scalar<T>()());
      float user_i2 = (input1.scalar<T>()());
      out_tensor->scalar<T>()() = std::plus<float>{}(user_i1, user_i2);
      return;
    }

    // Both inputs have the same shape (checked above), so either one can
    // supply the rank; prefer the MKL shape when the input carries one.
    mkl_context.in_dims = input2_in_mkl_format
                              ? mkl_context.input2_shape.GetDimension()
                              : input1.dims();

    // If there is nothing to compute, return.
    if (!input1_in_mkl_format && !input2_in_mkl_format) {
      const TensorShape& o_shape = input0.shape();
      if (o_shape.num_elements() == 0) {
        Tensor* out_tensor = nullptr;
        mkl_context.output_shape.SetMklTensor(false);
        AllocateOutputSetMklShape(ctx, src1_idx, &out_tensor, o_shape,
                                  mkl_context.output_shape);
        return;
      }
    }

    mkl_context.in_sizes = new size_t[mkl_context.in_dims];
    mkl_context.in_strides = new size_t[mkl_context.in_dims];
    // Generate sizes and strides for the input: from the MKL shape when one
    // is available, otherwise from the TensorFlow shape.
    if (input1_in_mkl_format || input2_in_mkl_format) {
      const MklShape* tmp_mkl_shape = (input1_in_mkl_format)
                                          ? &mkl_context.input1_shape
                                          : &mkl_context.input2_shape;
      for (int i = 0; i < mkl_context.in_dims; i++) {
        mkl_context.in_sizes[i] = tmp_mkl_shape->GetSizes()[i];
        mkl_context.in_strides[i] = tmp_mkl_shape->GetStrides()[i];
      }
    } else {
      for (int i = 0; i < mkl_context.in_dims; i++) {
        mkl_context.in_sizes[i] =
            input0.dim_size((mkl_context.in_dims - 1) - i);
      }
      mkl_context.in_strides[0] = 1;
      for (int i = 1; i < mkl_context.in_dims; i++) {
        mkl_context.in_strides[i] =
            mkl_context.in_strides[i - 1] * mkl_context.in_sizes[i - 1];
      }
    }
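    // In the TensorFlow-format branch above, for example (illustrative
    // only), a TF tensor of shape [N, H, W, C] is described to MKL with the
    // dimensions reversed, in_sizes = {C, W, H, N}, and strides built as a
    // running product: in_strides = {1, C, C * W, C * W * H}.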
    std::vector<float> coeff(2, 1.0);
    mkl_context.MklCreateInputLayouts(ctx);
    CHECK_EQ(dnnSumCreate_F32(&mkl_context.Eltwise, mkl_context.attributes, 2,
                              mkl_context.lt_input1, &coeff[0]),
             E_SUCCESS);
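    // The sum primitive computes dst = coeff[0] * src1 + coeff[1] * src2;
    // with both coefficients set to 1.0 this is a plain element-wise add.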

    Tensor mkl_tmp_input1_buf_tensor, mkl_tmp_input2_buf_tensor;
    mkl_context.MklPrepareAddNInputs(ctx, &mkl_tmp_input1_buf_tensor,
                                     &mkl_tmp_input2_buf_tensor);
    Tensor* output = nullptr;
    if (input1_in_mkl_format || input2_in_mkl_format) {
      TensorShape tf_shape;
      mkl_context.output_shape.SetMklTensor(true);
      mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise,
                                            dnnResourceDst);

      mkl_context.output_shape.SetTfLayout(
          mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
      if (input1_in_mkl_format) {
        mkl_context.output_shape.SetTfDimOrder(
            mkl_context.in_dims, mkl_context.input1_shape.GetTfToMklDimMap());
      } else {
        mkl_context.output_shape.SetTfDimOrder(
            mkl_context.in_dims, mkl_context.input2_shape.GetTfToMklDimMap());
      }
      tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                          mkl_context.output_shape.GetMklLayout())) /
                      sizeof(T));

      AllocateOutputSetMklShape(ctx, src1_idx, &output, tf_shape,
                                mkl_context.output_shape);
    } else {
      const TensorShape& o_shape = input1.shape();
      mkl_context.output_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(ctx, src1_idx, &output, o_shape,
                                mkl_context.output_shape);
    }

    mkl_context.Eltwise_res[dnnResourceDst] =
        static_cast<void*>(output->flat<T>().data());

    // Execute the sum (element-wise add) primitive.
    CHECK_EQ(dnnExecute_F32(mkl_context.Eltwise, mkl_context.Eltwise_res),
             E_SUCCESS);

    mkl_context.MklCleanup();
  }

 private:
  typedef struct {
    int in_dims;
    size_t* in_sizes = nullptr;
    size_t* in_strides = nullptr;
    dnnPrimitive_t Eltwise = nullptr;
    dnnPrimitiveAttributes_t attributes = nullptr;
    void* Eltwise_res[dnnResourceNumber];
    dnnLayout_t lt_input1 = nullptr, lt_input2 = nullptr;
    MklShape input1_shape, input2_shape, output_shape;

    void MklCreateInputLayouts(OpKernelContext* context) {
      bool input1_in_mkl_format = input1_shape.IsMklTensor();
      if (!input1_in_mkl_format) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input1, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_input1 = static_cast<dnnLayout_t>(input1_shape.GetCurLayout());
      }

      bool input2_in_mkl_format = input2_shape.IsMklTensor();
      if (!input2_in_mkl_format) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input2, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_input2 = static_cast<dnnLayout_t>(input2_shape.GetCurLayout());
      }
    }

    void MklPrepareAddNInputs(OpKernelContext* context,
                              Tensor* mkl_tmp_input1_buf_tensor,
                              Tensor* mkl_tmp_input2_buf_tensor) {
      bool mkl_convert_input1, mkl_convert_input2;
      dnnPrimitive_t mkl_prim_convert_input1 = nullptr,
                     mkl_prim_convert_input2 = nullptr;
      dnnLayout_t mkl_lt_internal_input1 = nullptr,
                  mkl_lt_internal_input2 = nullptr;
      void *mkl_buf_convert_input1 = nullptr, *mkl_buf_convert_input2 = nullptr;
      dnnResourceType_t dnnResourceMultipleSrc2 =
          (dnnResourceType_t)(dnnResourceMultipleSrc + 1);
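      // The i-th input of an MKL sum primitive is bound to resource slot
      // dnnResourceMultipleSrc + i, so the second input lives at
      // dnnResourceMultipleSrc + 1.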
      // Compare with internal layouts and convert if needed
      const Tensor& input1 = MklGetInput(context, 0);

      void* mkl_buf_input1 =
          const_cast<void*>(static_cast<const void*>(input1.flat<T>().data()));

      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                   &mkl_lt_internal_input1, Eltwise, dnnResourceMultipleSrc),
               E_SUCCESS);
      mkl_convert_input1 =
          !dnnLayoutCompare_F32(mkl_lt_internal_input1, lt_input1);
      if (mkl_convert_input1) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input1, lt_input1,
                                         mkl_lt_internal_input1),
                 E_SUCCESS);
        AllocTmpBuffer(context, mkl_tmp_input1_buf_tensor,
                       mkl_lt_internal_input1, &mkl_buf_convert_input1);
        CHECK_EQ(
            dnnConversionExecute_F32(mkl_prim_convert_input1, mkl_buf_input1,
                                     mkl_buf_convert_input1),
            E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_input1);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_input1);

      Eltwise_res[dnnResourceMultipleSrc] =
          (mkl_convert_input1) ? mkl_buf_convert_input1 : mkl_buf_input1;

      const Tensor& input2 = MklGetInput(context, 1);
      void* mkl_buf_input2 =
          const_cast<void*>(static_cast<const void*>(input2.flat<T>().data()));
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                   &mkl_lt_internal_input2, Eltwise, dnnResourceMultipleSrc2),
               E_SUCCESS);
      mkl_convert_input2 =
          !dnnLayoutCompare_F32(mkl_lt_internal_input2, lt_input2);
      if (mkl_convert_input2) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input2, lt_input2,
                                         mkl_lt_internal_input2),
                 E_SUCCESS);
        AllocTmpBuffer(context, mkl_tmp_input2_buf_tensor,
                       mkl_lt_internal_input2, &mkl_buf_convert_input2);
        CHECK_EQ(
            dnnConversionExecute_F32(mkl_prim_convert_input2, mkl_buf_input2,
                                     mkl_buf_convert_input2),
            E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_input2);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_input2);

      Eltwise_res[dnnResourceMultipleSrc2] =
          (mkl_convert_input2) ? mkl_buf_convert_input2 : mkl_buf_input2;
    }

    void MklCleanup() {
      dnnDelete_F32(Eltwise);
      // in_sizes and in_strides are allocated whenever Compute() reaches the
      // primitive-creation path and default to nullptr otherwise, so an
      // unconditional delete[] is safe regardless of the input formats.
      delete[] in_sizes;
      delete[] in_strides;
      bool input1_in_mkl_format = input1_shape.IsMklTensor();
      bool input2_in_mkl_format = input2_shape.IsMklTensor();
      if (!input1_in_mkl_format) {
        dnnLayoutDelete_F32(lt_input1);
      }
      if (!input2_in_mkl_format) {
        dnnLayoutDelete_F32(lt_input2);
      }
    }
  } MklAddNOpContext;
};

#else  // INTEL_MKL_ML
template <typename Device, typename T>
class MklAddNOp : public OpKernel {
 public:
  ~MklAddNOp() {}
  explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const int num = ctx->num_inputs();
    // Only addition of two input tensors is supported for now; num_inputs()
    // also counts the MKL shape meta tensor paired with each data input.
    OP_REQUIRES(ctx, num / 2 == 2,
                errors::InvalidArgument("Only addition of two tensors is "
                                        "supported by MKL. Num inputs: ",
                                        num));

    try {
      auto cpu_engine = engine(engine::cpu, 0);
      size_t src1_idx = 0, src2_idx = 1, output_idx = 0;
      const Tensor& src1_tensor = MklGetInput(ctx, src1_idx);
      const Tensor& src2_tensor = MklGetInput(ctx, src2_idx);

      MklDnnShape src1_mkl_shape, src2_mkl_shape;
      GetMklShape(ctx, src1_idx, &src1_mkl_shape);
      GetMklShape(ctx, src2_idx, &src2_mkl_shape);
      bool input1_in_mkl_format = src1_mkl_shape.IsMklTensor();
      bool input2_in_mkl_format = src2_mkl_shape.IsMklTensor();
      int src1_dims_size = input1_in_mkl_format ? src1_mkl_shape.GetDimension()
                                                : src1_tensor.dims();
      int src2_dims_size = input2_in_mkl_format ? src2_mkl_shape.GetDimension()
                                                : src2_tensor.dims();
      // If the shapes of the two tensors differ, raise an op error.
      TensorShape src1_shape, src2_shape;
      src1_shape = input1_in_mkl_format ? src1_mkl_shape.GetTfShape()
                                        : src1_tensor.shape();
      src2_shape = input2_in_mkl_format ? src2_mkl_shape.GetTfShape()
                                        : src2_tensor.shape();

      if (!src1_shape.IsSameSize(src2_shape)) {
        ctx->SetStatus(errors::InvalidArgument(
            "Inputs to operation ", this->name(), " of type ",
            this->type_string(),
            " must have the same size and shape.  Input 0: ",
            src1_shape.DebugString(),
            " != input 1: ", src2_shape.DebugString()));
        return;
      }

      if (!input1_in_mkl_format && src1_dims_size == 0) {
        Tensor* dst_tensor = nullptr;
        MklShape mkl_shape_dst;
        mkl_shape_dst.SetMklTensor(false);
        AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
                                  src1_tensor.shape(), mkl_shape_dst);
        float user_i1 = (src1_tensor.scalar<T>()());
        float user_i2 = (src2_tensor.scalar<T>()());
        dst_tensor->scalar<T>()() = std::plus<float>{}(user_i1, user_i2);
        return;
      }

      // If there is nothing to compute, return.
      if (!input1_in_mkl_format && !input2_in_mkl_format) {
        if (src1_tensor.shape().num_elements() == 0) {
          Tensor* dst_tensor = nullptr;
          MklShape mkl_shape_dst;
          mkl_shape_dst.SetMklTensor(false);
          AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
                                    src1_tensor.shape(), mkl_shape_dst);
          return;
        }
      }

      std::vector<double> coeff(2, 1.0);
      MklDnnData<T> src1(&cpu_engine);
      MklDnnData<T> src2(&cpu_engine);
      MklDnnData<T> dst(&cpu_engine);

      int tmp_size = input1_in_mkl_format ? src2_dims_size : src1_dims_size;
      memory::dims dims(tmp_size);
      memory::dims strides(tmp_size);
      memory::desc md1({}, memory::data_undef, memory::format_undef);
      memory::desc md2({}, memory::data_undef, memory::format_undef);

      // To create the Sum primitive, all inputs must share one format. In
      // the mixed case, where one input is in TensorFlow format and the
      // other is in MKL format, we must pick a common format for primitive
      // construction. For performance we treat both inputs as being in MKL
      // format in that case, and insert a reorder that converts the
      // TensorFlow-format input to MKL format. If both inputs are already in
      // MKL format, or both are in TensorFlow format, no reorder is needed.
      if (!input1_in_mkl_format && !input2_in_mkl_format) {
        // If both inputs are in TensorFlow format, create a blocked memory
        // descriptor from the TF shape and row-major strides.
        dims = TFShapeToMklDnnDims(src1_tensor.shape());
        strides = CalculateTFStrides(dims);
        md1 = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
        md2 = md1;
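        // For example (illustrative only): a TF shape [2, 3, 4] yields
        // dims = {2, 3, 4} and row-major strides = {12, 4, 1}, i.e.
        // strides[i] is the product of all dims to the right of i.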
      } else if (input1_in_mkl_format && !input2_in_mkl_format) {
        // If one input is in MKL format and the other is in TensorFlow
        // format, create descriptors that reflect the actual case: for the
        // MKL-format input we take the MKL layout from its MklDnnShape, and
        // for the TensorFlow-format input we build a memory descriptor from
        // the data format.
        md1 = src1_mkl_shape.GetMklLayout();

        memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat();
        auto src1_tf_data_format =
            MklDnnDataFormatToTFDataFormat(src1_mkl_data_format);
        auto src2_dims =
            TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), src1_tf_data_format);
        md2 = memory::desc(src2_dims, MklDnnType<T>(), src1_mkl_data_format);
      } else if (input2_in_mkl_format && !input1_in_mkl_format) {
        // Same as above, with the roles of the inputs swapped.
        memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat();
        auto src2_tf_data_format =
            MklDnnDataFormatToTFDataFormat(src2_mkl_data_format);
        auto src1_dims =
            TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), src2_tf_data_format);
        md1 = memory::desc(src1_dims, MklDnnType<T>(), src2_mkl_data_format);

        md2 = src2_mkl_shape.GetMklLayout();
      } else {
        // If both inputs are in MKL format, use the MKL layouts of the input
        // tensors directly.
        md1 = src1_mkl_shape.GetMklLayout();
        md2 = src2_mkl_shape.GetMklLayout();
      }
      src1.SetUsrMem(md1, &src1_tensor);
      src2.SetUsrMem(md2, &src2_tensor);

      // As noted above, we tell MKL-DNN that both inputs share one format,
      // so we pick a common memory descriptor: the MKL-format descriptor if
      // either input carries one. The output memory descriptor is likewise
      // set to MKL format in that case.
      memory::desc common_md({}, memory::data_undef, memory::format_undef);
      if (input1_in_mkl_format || input2_in_mkl_format) {
        common_md = input1_in_mkl_format ? md1 : md2;
        dst.SetUsrMem(common_md);
      } else {
        // Both inputs are in TensorFlow format and have the same shape, so
        // either memory descriptor works.
        common_md = md1;
        dst.SetUsrMem(common_md);
      }

      std::vector<memory::primitive_desc> srcs_pd;
      // Memory primitive descriptor for the 1st input.
      srcs_pd.push_back(memory::primitive_desc(common_md, cpu_engine));
      // Memory primitive descriptor for the 2nd input.
      srcs_pd.push_back(memory::primitive_desc(common_md, cpu_engine));
      auto sum_pd = sum::primitive_desc(dst.GetUsrMemDesc(), coeff, srcs_pd);
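      // The sum primitive descriptor encodes dst = coeff[0] * src1 +
      // coeff[1] * src2; both coefficients are 1.0, giving element-wise add.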

      // Now we set up resources for primitive execution.
      // First, check whether any input needs a reorder, per the logic
      // described above. Since the output will be in MKL format when at
      // least one input is in MKL format, we reorder toward the chosen
      // output descriptor.
      std::vector<primitive::at> inputs;
      std::vector<primitive> net;
      // If the actual input format of a tensor differs from the common_md we
      // told MKL-DNN about, a reorder is needed.
      src1.CheckReorderToOpMem(srcs_pd[0], &net);
      src2.CheckReorderToOpMem(srcs_pd[1], &net);
      inputs.push_back(src1.GetOpMem());
      inputs.push_back(src2.GetOpMem());
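      // CheckReorderToOpMem appends a reorder primitive to `net` only when
      // the user memory layout differs from the requested one; GetOpMem then
      // returns the (possibly reordered) memory to feed the sum primitive.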

      // Allocate the output tensor now.
      Tensor* dst_tensor = nullptr;
      MklDnnShape output_mkl_shape;
      TensorShape output_tf_shape;

      if (input2_in_mkl_format || input1_in_mkl_format) {
        output_mkl_shape.SetMklTensor(true);
        auto output_pd = dst.GetUsrMemPrimDesc();
        output_mkl_shape.SetMklLayout(&output_pd);
        output_mkl_shape.SetElemType(MklDnnType<T>());
        if (input1_in_mkl_format) {
          output_mkl_shape.SetTfLayout(src1_dims_size,
                                       src1_mkl_shape.GetSizesAsMklDnnDims(),
                                       src1_mkl_shape.GetTfDataFormat());
        } else {
          output_mkl_shape.SetTfLayout(src2_dims_size,
                                       src2_mkl_shape.GetSizesAsMklDnnDims(),
                                       src2_mkl_shape.GetTfDataFormat());
        }
        // Allocate the TF tensor as a flat buffer sized to the MKL layout.
        output_tf_shape.AddDim((output_pd.get_size() / sizeof(T)));
      } else {
        output_mkl_shape.SetMklTensor(false);
        output_tf_shape = src1_tensor.shape();
      }
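      // When the output stays in MKL format, the logical TF shape travels in
      // output_mkl_shape, while the allocated tensor itself is just raw MKL
      // storage of output_pd.get_size() bytes.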
      AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, output_tf_shape,
                                output_mkl_shape);
      dst.SetUsrMemDataHandle(dst_tensor);

      // Create the sum primitive and submit the net for execution.
      net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
      stream(stream::kind::eager).submit(net).wait();
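      // The eager stream runs every queued primitive (any reorders plus the
      // sum) synchronously; wait() blocks until execution finishes.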
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          ctx, errors::Aborted("Operation received an exception:", error_msg));
    }
  }
};

#endif  // INTEL_MKL_ML
#define REGISTER_MKL_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("_MklAddN")                          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklAddNOp<CPUDevice, T>);

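// Register the kernel for float only; the kMklOpLabel constraint ensures it
// is selected only for _MklAddN nodes produced by the MKL graph-rewrite pass.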
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
}  // namespace tensorflow
#endif  // INTEL_MKL