/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace mul {

// This file has three implementations of Mul.
30enum KernelType { 31 kReference, 32 kGenericOptimized, // Neon-free 33 kNeonOptimized, 34}; 35 36constexpr int kInputTensor1 = 0; 37constexpr int kInputTensor2 = 1; 38constexpr int kOutputTensor = 0; 39 40struct OpData { 41 bool requires_broadcast; 42}; 43 44void* Init(TfLiteContext* context, const char* buffer, size_t length) { 45 auto* data = new OpData; 46 data->requires_broadcast = false; 47 return data; 48} 49 50void Free(TfLiteContext* context, void* buffer) { 51 delete reinterpret_cast<OpData*>(buffer); 52} 53 54TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 55 OpData* data = reinterpret_cast<OpData*>(node->user_data); 56 57 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); 58 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 59 60 TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); 61 TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 62 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 63 64 TF_LITE_ENSURE_EQ(context, input1->type, input2->type); 65 output->type = input2->type; 66 67 data->requires_broadcast = !HaveSameShapes(input1, input2); 68 69 TfLiteIntArray* output_size = nullptr; 70 if (data->requires_broadcast) { 71 TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( 72 context, input1, input2, &output_size)); 73 } else { 74 output_size = TfLiteIntArrayCopy(input1->dims); 75 } 76 77 return context->ResizeTensor(context, output, output_size); 78} 79 80template <KernelType kernel_type> 81void EvalFloat(TfLiteContext* context, TfLiteNode* node, 82 TfLiteMulParams* params, const OpData* data, 83 TfLiteTensor* input1, TfLiteTensor* input2, 84 TfLiteTensor* output) { 85 float output_activation_min, output_activation_max; 86 CalculateActivationRangeFloat(params->activation, &output_activation_min, 87 &output_activation_max); 88#define TF_LITE_MUL(type, opname) \ 89 type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \ 90 GetTensorData<float>(input2), GetTensorDims(input2), \ 91 
output_activation_min, output_activation_max, \ 92 GetTensorData<float>(output), GetTensorDims(output)) 93 if (kernel_type == kReference) { 94 if (data->requires_broadcast) { 95 TF_LITE_MUL(reference_ops, BroadcastMul); 96 } else { 97 TF_LITE_MUL(reference_ops, Mul); 98 } 99 } else { 100 if (data->requires_broadcast) { 101 TF_LITE_MUL(optimized_ops, BroadcastMul); 102 } else { 103 TF_LITE_MUL(optimized_ops, Mul); 104 } 105 } 106#undef TF_LITE_MUL 107} 108 109template <KernelType kernel_type> 110void EvalQuantized(TfLiteContext* context, TfLiteNode* node, 111 TfLiteMulParams* params, const OpData* data, 112 TfLiteTensor* input1, TfLiteTensor* input2, 113 TfLiteTensor* output) { 114 auto input1_offset = -input1->params.zero_point; 115 auto input2_offset = -input2->params.zero_point; 116 auto output_offset = output->params.zero_point; 117 118 int32_t output_multiplier; 119 int output_shift; 120 121 double real_multiplier = 122 input1->params.scale * input2->params.scale / output->params.scale; 123 QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, 124 &output_shift); 125 126 int32 output_activation_min, output_activation_max; 127 CalculateActivationRangeUint8(params->activation, output, 128 &output_activation_min, &output_activation_max); 129 130#define TF_LITE_MUL(type, opname) \ 131 type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \ 132 input1_offset, GetTensorData<uint8_t>(input2), \ 133 GetTensorDims(input2), input2_offset, output_offset, \ 134 output_multiplier, output_shift, output_activation_min, \ 135 output_activation_max, GetTensorData<uint8_t>(output), \ 136 GetTensorDims(output)); 137 // The quantized version of Mul doesn't support activations, so we 138 // always use BroadcastMul. 
139 if (kernel_type == kReference) { 140 TF_LITE_MUL(reference_ops, BroadcastMul); 141 } else { 142 TF_LITE_MUL(optimized_ops, BroadcastMul); 143 } 144#undef TF_LITE_MUL 145} 146 147template <KernelType kernel_type> 148TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 149 auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data); 150 OpData* data = reinterpret_cast<OpData*>(node->user_data); 151 152 TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); 153 TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 154 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 155 156 if (output->type == kTfLiteFloat32) { 157 EvalFloat<kernel_type>(context, node, params, data, input1, input2, output); 158 } else if (output->type == kTfLiteUInt8) { 159 EvalQuantized<kernel_type>(context, node, params, data, input1, input2, 160 output); 161 } else { 162 context->ReportError(context, 163 "Mul only supports FLOAT32 and quantized UINT8 now."); 164 return kTfLiteError; 165 } 166 167 return kTfLiteOk; 168} 169 170} // namespace mul 171 172TfLiteRegistration* Register_MUL_REF() { 173 static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, 174 mul::Eval<mul::kReference>}; 175 return &r; 176} 177 178TfLiteRegistration* Register_MUL_GENERIC_OPT() { 179 static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, 180 mul::Eval<mul::kGenericOptimized>}; 181 return &r; 182} 183 184TfLiteRegistration* Register_MUL_NEON_OPT() { 185 static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, 186 mul::Eval<mul::kNeonOptimized>}; 187 return &r; 188} 189 190TfLiteRegistration* Register_MUL() { 191#ifdef USE_NEON 192 return Register_MUL_NEON_OPT(); 193#else 194 return Register_MUL_GENERIC_OPT(); 195#endif 196} 197 198} // namespace builtin 199} // namespace ops 200} // namespace tflite 201