/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace mul {

// This file has three implementations of Mul.
enum KernelType {
  kReference,
  kGenericOptimized,  // Neon-free
  kNeonOptimized,
};

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

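// Per-node state, created in Init and filled in by Prepare: records whether
// the two inputs have different shapes and therefore need broadcasting at
// Eval time.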
struct OpData {
  bool requires_broadcast;
};

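// Allocates the per-node OpData; the interpreter hands it back through
// node->user_data and releases it via Free.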
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  data->requires_broadcast = false;
  return data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

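// Validates the input/output counts and types, records whether broadcasting
// is needed, and resizes the output tensor to the (possibly broadcast) shape.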
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
  output->type = input2->type;

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}

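// Multiplies two float tensors element-wise (broadcasting if needed) and
// clamps the result to the range implied by the fused activation.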
template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
               TfLiteMulParams* params, const OpData* data,
               TfLiteTensor* input1, TfLiteTensor* input2,
               TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRangeFloat(params->activation, &output_activation_min,
                                &output_activation_max);
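// TF_LITE_MUL expands to a call into the given ops namespace (reference_ops
// or optimized_ops) with the float tensor data, dims, and the activation
// clamping range computed above.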
#define TF_LITE_MUL(type, opname)                                   \
  type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
               GetTensorData<float>(input2), GetTensorDims(input2), \
               output_activation_min, output_activation_max,        \
               GetTensorData<float>(output), GetTensorDims(output))
  if (kernel_type == kReference) {
    if (data->requires_broadcast) {
      TF_LITE_MUL(reference_ops, BroadcastMul);
    } else {
      TF_LITE_MUL(reference_ops, Mul);
    }
  } else {
    if (data->requires_broadcast) {
      TF_LITE_MUL(optimized_ops, BroadcastMul);
    } else {
      TF_LITE_MUL(optimized_ops, Mul);
    }
  }
#undef TF_LITE_MUL
}

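// Multiplies two uint8 quantized tensors. Each input is offset by the
// negative of its zero point, the products are rescaled into the output scale
// with a fixed-point multiplier, offset by the output zero point, and clamped
// to the fused activation range.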
template <KernelType kernel_type>
void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                   TfLiteMulParams* params, const OpData* data,
                   TfLiteTensor* input1, TfLiteTensor* input2,
                   TfLiteTensor* output) {
  auto input1_offset = -input1->params.zero_point;
  auto input2_offset = -input2->params.zero_point;
  auto output_offset = output->params.zero_point;

  int32_t output_multiplier;
  int output_shift;

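  // (q1 - z1) * (q2 - z2) carries an effective scale of s1 * s2, so
  // requantizing into the output scale multiplies by s1 * s2 / s_out,
  // expressed below as a fixed-point multiplier plus a right shift.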
  double real_multiplier =
      input1->params.scale * input2->params.scale / output->params.scale;
  QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
                                   &output_shift);

  int32_t output_activation_min, output_activation_max;
  CalculateActivationRangeUint8(params->activation, output,
                                &output_activation_min, &output_activation_max);

#define TF_LITE_MUL(type, opname)                                      \
  type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1),  \
               input1_offset, GetTensorData<uint8_t>(input2),          \
               GetTensorDims(input2), input2_offset, output_offset,    \
               output_multiplier, output_shift, output_activation_min, \
               output_activation_max, GetTensorData<uint8_t>(output),  \
               GetTensorDims(output));
  // Only a broadcast-capable quantized Mul kernel is provided, so
  // BroadcastMul is used whether or not the input shapes match.
  if (kernel_type == kReference) {
    TF_LITE_MUL(reference_ops, BroadcastMul);
  } else {
    TF_LITE_MUL(optimized_ops, BroadcastMul);
  }
#undef TF_LITE_MUL
}

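// Dispatches to the float or quantized implementation based on the output
// tensor's type.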
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  if (output->type == kTfLiteFloat32) {
    EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8) {
    EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
                               output);
  } else {
    context->ReportError(context,
                         "Mul only supports FLOAT32 and quantized UINT8 now.");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace mul

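// Each registration returns a statically allocated TfLiteRegistration bound
// to one specialization of mul::Eval.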
TfLiteRegistration* Register_MUL_REF() {
  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
                                 mul::Eval<mul::kReference>};
  return &r;
}

TfLiteRegistration* Register_MUL_GENERIC_OPT() {
  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
                                 mul::Eval<mul::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_MUL_NEON_OPT() {
  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
                                 mul::Eval<mul::kNeonOptimized>};
  return &r;
}

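// Default registration: uses the NEON kernel when built with NEON support,
// otherwise the generic optimized kernel.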
TfLiteRegistration* Register_MUL() {
#ifdef USE_NEON
  return Register_MUL_NEON_OPT();
#else
  return Register_MUL_GENERIC_OPT();
#endif
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite