15eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
25eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
35eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank ChenLicensed under the Apache License, Version 2.0 (the "License");
45eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenyou may not use this file except in compliance with the License.
55eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank ChenYou may obtain a copy of the License at
65eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
75eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen    http://www.apache.org/licenses/LICENSE-2.0
85eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
95eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank ChenUnless required by applicable law or agreed to in writing, software
105eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chendistributed under the License is distributed on an "AS IS" BASIS,
115eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank ChenWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank ChenSee the License for the specific language governing permissions and
135eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenlimitations under the License.
145eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen==============================================================================*/
155eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
165eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#ifndef TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
175eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
185eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#ifdef INTEL_MKL
195eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
205eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#include <string>
215eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#include "tensorflow/core/framework/op_kernel.h"
225eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
235eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chennamespace tensorflow {
// Since our ops are going to produce and also consume N additional tensors
// (Mkl) for N Tensorflow tensors, we can have following different
// orderings among these 2N tensors.
//
// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
// consume A_m, B_m, and C_m additionally.
//
// INTERLEAVED: in this case 2N tensors are interleaved. So for above
//              example, the ordering looks like: A, A_m, B, B_m, C, C_m.
//
// CONTIGUOUS: in this case N Tensorflow tensors are contiguous followed
//             by N Mkl tensors. So for above example, the ordering looks
//             like: A, B, C, A_m, B_m, C_m
//
// Following APIs map index of original Tensorflow tensors to their
// appropriate position based on selected ordering. For contiguous ordering,
// we need to know the total number of tensors (parameter total).
//
// Kept as an unscoped enum so values can be passed to CHECK_EQ, which streams
// its operands on failure.
enum MklTfTensorOrdering { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS };

// NOTE: Currently, we use contiguous ordering. If you change this, then you
// would need to change Mkl op definitions in nn_ops.cc.
// constexpr (rather than a mutable static) so every translation unit that
// includes this header sees a read-only compile-time constant and the
// ordering branches in the helpers below can be folded away.
static constexpr MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
465eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
477149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// Get index of MetaData tensor from index 'n' of Data tensor.
487149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlowerinline int DataIndexToMetaDataIndex(int n, int total_tensors) {
497149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
507149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    // For interleaved ordering, Mkl tensor follows immediately after
517149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    // Tensorflow tensor.
527149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    return n + 1;
537149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  } else {
547149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
557149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
567149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    return n + total_tensors / 2;
575eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen  }
587149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower}
595eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
607149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlowerint inline GetTensorDataIndex(int n, int total_tensors) {
617149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
627149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    return 2 * n;  // index corresponding to nth input/output tensor
637149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  } else {
647149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
657149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    return n;
667149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  }
677149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower}
685eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
697149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlowerint inline GetTensorMetaDataIndex(int n, int total_tensors) {
707149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  // Get index for TensorData first and then use mapping function
717149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  // to get TensorMetaData index from TensorData index.
727149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  int tidx = GetTensorDataIndex(n, total_tensors);
737149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  return DataIndexToMetaDataIndex(tidx, total_tensors);
747149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower}
755eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
765eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chennamespace mkl_op_registry {
// Kernel label attached to MKL kernels, and the form in which that label
// appears in the string returned by KernelsRegisteredForOp().
// Both pointer and pointee are const so these header-defined constants
// cannot be reassigned (matching kMklOpPrefix below).
static const char* const kMklOpLabel = "MklOp";
static const char* const kMklOpLabelPattern = "label='MklOp'";
// Prefix that we add to Tensorflow op name to construct Mkl op name.
static const char* const kMklOpPrefix = "_Mkl";
815eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
827149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// Get the name of Mkl op from original TensorFlow op
837149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// We prefix 'Mkl' to the original op to get Mkl op.
847149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlowerinline string GetMklOpName(const string& name) {
857149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  return string(kMklOpPrefix) + name;
867149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower}
875eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
887149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// Check whether opname with type T is registered as MKL-compliant.
897149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower//
907149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// @input: name of the op
917149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// @input: T datatype to be used for checking op
927149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// @return: true if opname is registered as Mkl op; false otherwise
937149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlowerstatic inline bool IsMklOp(const std::string& op_name, DataType T) {
947149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  string kernel = KernelsRegisteredForOp(op_name);
957149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  bool result =
967149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower      kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
977149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  return result;
987149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower}
995eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
1007149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// Check whether opname with type T is registered as MKL-compliant and
1017149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// is element-wise.
1027149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower//
1037149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// @input: name of the op
1047149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// @input: T datatype to be used for checking op
1057149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// @return: true if opname is registered as element-wise Mkl op;
1067149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower// false otherwise
1077149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlowerstatic inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
1087149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  if (!IsMklOp(op_name, T)) {
1097149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower    return false;
110fe8406149feec453250905965a14285465cd2063Shanqing Cai  }
1117149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
1127149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower                 0 == op_name.compare(GetMklOpName("Sub")) ||
1137149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower                 0 == op_name.compare(GetMklOpName("Mul")) ||
1147149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower                 0 == op_name.compare(GetMklOpName("Maximum")) ||
1157149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower                 0 == op_name.compare(GetMklOpName("SquaredDifference")));
1167149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower
1177149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower  return result;
1187149a2e2e2f549035f23e21224ee41afe8df3876A. Unique TensorFlower}
1195eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen}  // namespace mkl_op_registry
1205eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen}  // namespace tensorflow
1215eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#endif  // INTEL_MKL
1225eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#endif  // TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
123