/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized eight-bit version of the add operation.

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#define QUANTIZED_ADD_USE_NEON
#include <arm_neon.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/bcast.h"

// There are implementations for three broadcast patterns for add:
//  - Scalar + Array
//  - Array + Array
//  - Array + Shorter Array (the shorter array is repeated to match the first)
//
// These cover the most common broadcast patterns, and there are NEON SIMD
// versions to accelerate them on ARM platforms.
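//
// As an illustration of the third pattern (the example shapes here are
// hypothetical, not taken from the surrounding code): adding a bias vector of
// shape [64] to a tensor of shape [32, 64] repeats the vector across each of
// the 32 rows.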

namespace tensorflow {
namespace {

// Requantizes every element of `full_input` and the single `scalar_input`
// value into the shared output range, then adds them elementwise.
template <class T, class Toutput>
void ScalarAddition(OpKernelContext* context, const T* full_input,
                    float full_input_min, float full_input_max,
                    int64 num_elements, T scalar_input, float scalar_input_min,
                    float scalar_input_max, float output_min, float output_max,
                    Toutput* output) {
  const Toutput scalar_in_output_range = RequantizeInNewRange<T, Toutput>(
      scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);
  for (int64 i = 0; i < num_elements; ++i) {
    const Toutput full_input_in_output_range = RequantizeInNewRange<T, Toutput>(
        full_input[i], full_input_min, full_input_max, output_min, output_max);
    output[i] = full_input_in_output_range + scalar_in_output_range;
  }
}
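
// Note on the quint8 -> qint32 specializations below: RequantizeInNewRange is
// an affine transform, so rather than converting every element through float,
// they precompute an integer intercept and slope. Converting the quantized
// values 0 and 1 to floats and then into the output range yields
// input_0_int64 (the intercept) and input_1_int64, whose difference is the
// slope, so that requantizing q reduces to input_0_int64 + q * slope, clamped
// to the representable qint32 range. The scalar loops below apply exactly
// this identity.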

#ifdef QUANTIZED_ADD_USE_NEON

template <>
void ScalarAddition(OpKernelContext* context, const quint8* full_input,
                    float full_input_min, float full_input_max,
                    int64 num_elements, quint8 scalar_input,
                    float scalar_input_min, float scalar_input_max,
                    float output_min, float output_max, qint32* output) {
  const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>(
      scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);

  // Integer intercept and slope for requantizing the full input (see the
  // note above these specializations).
  const float input_0_float =
      QuantizedToFloat<quint8>(0, full_input_min, full_input_max);
  const float input_1_float =
      QuantizedToFloat<quint8>(1, full_input_min, full_input_max);
  const int64 input_0_int64 =
      FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max);
  const int64 input_1_int64 =
      FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max);
  const int32 input_mult_int32 = input_1_int64 - input_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  // Broadcast the intercept, slope, and requantized scalar into NEON
  // registers for the vectorized loop.
  const int64x2_t input_0_64x2 = vmovq_n_s64(input_0_int64);
  const int32x2_t input_mult_32x2 = vmov_n_s32(input_mult_int32);
  const int32x4_t scalar_in_output_range_32x4 =
      vmovq_n_s32(scalar_in_output_range);
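  // Main vectorized loop: Requantize8x8To32Neon consumes eight uint8 values
  // per iteration and yields two int32x4 vectors already in the output range;
  // the requantized scalar is added to each before the results are stored.
  // Any elements left over when num_elements is not a multiple of eight are
  // handled by the scalar tail loop that follows.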
  int64 i = 0;
  for (; i < (num_elements - 7); i += 8) {
    const uint8* full_input_ptr = &(full_input->value) + i;
    const std::array<int32x4_t, 2> output_value =
        Requantize8x8To32Neon(full_input_ptr, input_0_64x2, input_mult_32x2);
    const int32x4_t result_low_32x4 =
        vaddq_s32(output_value[0], scalar_in_output_range_32x4);
    const int32x4_t result_high_32x4 =
        vaddq_s32(output_value[1], scalar_in_output_range_32x4);
    int32* output_ptr = &(output->value) + i;
    vst1q_s32(output_ptr + 0, result_low_32x4);
    vst1q_s32(output_ptr + 4, result_high_32x4);
  }
  for (; i < num_elements; ++i) {
    const int64 full_input_value = static_cast<int64>(full_input[i]);
    int64 full_input_in_output_range_64 =
        input_0_int64 + (full_input_value * input_mult_int32);
    full_input_in_output_range_64 =
        std::max(full_input_in_output_range_64, lowest_quantized);
    full_input_in_output_range_64 =
        std::min(full_input_in_output_range_64, highest_quantized);
    const int32 full_input_in_output_range =
        static_cast<int32>(full_input_in_output_range_64);
    output[i] = full_input_in_output_range + scalar_in_output_range;
  }
}

#else  // QUANTIZED_ADD_USE_NEON

template <>
void ScalarAddition(OpKernelContext* context, const quint8* full_input,
                    float full_input_min, float full_input_max,
                    int64 num_elements, quint8 scalar_input,
                    float scalar_input_min, float scalar_input_max,
                    float output_min, float output_max, qint32* output) {
  const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>(
      scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);

  const float input_0_float =
      QuantizedToFloat<quint8>(0, full_input_min, full_input_max);
  const float input_1_float =
      QuantizedToFloat<quint8>(1, full_input_min, full_input_max);
  const int64 input_0_int64 =
      FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max);
  const int64 input_1_int64 =
      FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max);
  const int32 input_mult_int32 = input_1_int64 - input_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  for (int64 i = 0; i < num_elements; ++i) {
    const int64 full_input_value = static_cast<int64>(full_input[i]);
    int64 full_input_in_output_range_64 =
        input_0_int64 + (full_input_value * input_mult_int32);
    full_input_in_output_range_64 =
        std::max(full_input_in_output_range_64, lowest_quantized);
    full_input_in_output_range_64 =
        std::min(full_input_in_output_range_64, highest_quantized);
    const int32 full_input_in_output_range =
        static_cast<int32>(full_input_in_output_range_64);
    output[i] = full_input_in_output_range + scalar_in_output_range;
  }
}

#endif  // QUANTIZED_ADD_USE_NEON

// Requantizes two equal-length arrays into the output range and adds them
// elementwise.
template <class T, class Toutput>
void VectorAddition(OpKernelContext* context, const T* x_data, float min_x,
                    float max_x, const T* y_data, float min_y, float max_y,
                    int64 num_elements, float output_min, float output_max,
                    Toutput* output) {
  for (int64 i = 0; i < num_elements; ++i) {
    const Toutput x_in_output_range = RequantizeInNewRange<T, Toutput>(
        x_data[i], min_x, max_x, output_min, output_max);
    const Toutput y_in_output_range = RequantizeInNewRange<T, Toutput>(
        y_data[i], min_y, max_y, output_min, output_max);
    output[i] = x_in_output_range + y_in_output_range;
  }
}

#ifdef QUANTIZED_ADD_USE_NEON

template <>
void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x,
                    float max_x, const quint8* y_data, float min_y, float max_y,
                    int64 num_elements, float output_min, float output_max,
                    qint32* output) {
  const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x);
  const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x);
  const int64 x_0_int64 =
      FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max);
  const int64 x_1_int64 =
      FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max);
  const int32 x_mult_int32 = x_1_int64 - x_0_int64;

  const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y);
  const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y);
  const int64 y_0_int64 =
      FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max);
  const int64 y_1_int64 =
      FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max);
  const int32 y_mult_int32 = y_1_int64 - y_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  const int64x2_t x_0_64x2 = vmovq_n_s64(x_0_int64);
  const int32x2_t x_mult_32x2 = vmov_n_s32(x_mult_int32);

  const int64x2_t y_0_64x2 = vmovq_n_s64(y_0_int64);
  const int32x2_t y_mult_32x2 = vmov_n_s32(y_mult_int32);

  int64 i = 0;
  for (; i < (num_elements - 7); i += 8) {
    const uint8* x_ptr = &(x_data->value) + i;
    const std::array<int32x4_t, 2> x_output_value =
        Requantize8x8To32Neon(x_ptr, x_0_64x2, x_mult_32x2);
    const uint8* y_ptr = &(y_data->value) + i;
    const std::array<int32x4_t, 2> y_output_value =
        Requantize8x8To32Neon(y_ptr, y_0_64x2, y_mult_32x2);

    const int32x4_t result_low_32x4 =
        vaddq_s32(x_output_value[0], y_output_value[0]);
    const int32x4_t result_high_32x4 =
        vaddq_s32(x_output_value[1], y_output_value[1]);
    int32* output_ptr = &(output->value) + i;
    vst1q_s32(output_ptr + 0, result_low_32x4);
    vst1q_s32(output_ptr + 4, result_high_32x4);
  }

  for (; i < num_elements; ++i) {
    const int64 x_value = static_cast<int64>(x_data[i]);
    int64 x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32);
    x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized);
    x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized);
    const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64);

    const int64 y_value = static_cast<int64>(y_data[i]);
    int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32);
    y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized);
    y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized);
    const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64);

    output[i] = x_in_output_range + y_in_output_range;
  }
}

#else  // QUANTIZED_ADD_USE_NEON

template <>
void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x,
                    float max_x, const quint8* y_data, float min_y, float max_y,
                    int64 num_elements, float output_min, float output_max,
                    qint32* output) {
  const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x);
  const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x);
  const int64 x_0_int64 =
      FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max);
  const int64 x_1_int64 =
      FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max);
  const int32 x_mult_int32 = x_1_int64 - x_0_int64;

  const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y);
  const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y);
  const int64 y_0_int64 =
      FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max);
  const int64 y_1_int64 =
      FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max);
  const int32 y_mult_int32 = y_1_int64 - y_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  for (int64 i = 0; i < num_elements; ++i) {
    const int64 x_value = static_cast<int64>(x_data[i]);
    int64 x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32);
    x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized);
    x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized);
    const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64);

    const int64 y_value = static_cast<int64>(y_data[i]);
    int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32);
    y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized);
    y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized);
    const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64);

    output[i] = x_in_output_range + y_in_output_range;
  }
}

#endif  // QUANTIZED_ADD_USE_NEON

// Adds a shorter `vector` to a larger `tensor`, repeating the vector as many
// times as needed to cover the tensor: the index into the vector is the
// tensor index modulo the vector length (for example, with a vector of length
// three, tensor indices 0 through 5 read vector indices 0, 1, 2, 0, 1, 2).
// Both operands are requantized into the output range before the addition.
template <class T, class Toutput>
void VectorTensorAddition(const T* vector_data, float min_vector,
                          float max_vector, int64 vector_num_elements,
                          const T* tensor_data, float min_tensor,
                          float max_tensor, int64 tensor_num_elements,
                          float output_min, float output_max, Toutput* output) {
  for (int64 i = 0; i < tensor_num_elements; ++i) {
    const int64 vector_i = i % vector_num_elements;
    const Toutput vector_in_output_range = RequantizeInNewRange<T, Toutput>(
        vector_data[vector_i], min_vector, max_vector, output_min, output_max);
    const Toutput tensor_in_output_range = RequantizeInNewRange<T, Toutput>(
        tensor_data[i], min_tensor, max_tensor, output_min, output_max);
    output[i] = vector_in_output_range + tensor_in_output_range;
  }
}

#ifdef QUANTIZED_ADD_USE_NEON

template <>
void VectorTensorAddition(const quint8* vector_data, float min_vector,
                          float max_vector, int64 vector_num_elements,
                          const quint8* tensor_data, float min_tensor,
                          float max_tensor, int64 tensor_num_elements,
                          float output_min, float output_max, qint32* output) {
  const float vector_0_float =
      QuantizedToFloat<quint8>(0, min_vector, max_vector);
  const float vector_1_float =
      QuantizedToFloat<quint8>(1, min_vector, max_vector);
  const int64 vector_0_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max);
  const int64 vector_1_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max);
  const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64;

  const float tensor_0_float =
      QuantizedToFloat<quint8>(0, min_tensor, max_tensor);
  const float tensor_1_float =
      QuantizedToFloat<quint8>(1, min_tensor, max_tensor);
  const int64 tensor_0_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max);
  const int64 tensor_1_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max);
  const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  const int64x2_t vector_0_64x2 = vmovq_n_s64(vector_0_int64);
  const int32x2_t vector_mult_32x2 = vmov_n_s32(vector_mult_int32);

  const int64x2_t tensor_0_64x2 = vmovq_n_s64(tensor_0_int64);
  const int32x2_t tensor_mult_32x2 = vmov_n_s32(tensor_mult_int32);

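  // The outer loop walks the tensor in vector-sized chunks while `i` tracks
  // the absolute tensor offset; the vector is requantized afresh for each
  // chunk. Note that this loop structure assumes tensor_num_elements is an
  // exact multiple of vector_num_elements.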
  for (int64 base_i = 0; base_i < tensor_num_elements;
       base_i += vector_num_elements) {
    int64 i = base_i;
    int64 vector_i = 0;
    for (; vector_i < (vector_num_elements - 7); vector_i += 8, i += 8) {
      const uint8* vector_ptr = &(vector_data->value) + vector_i;
      const std::array<int32x4_t, 2> vector_output_value =
          Requantize8x8To32Neon(vector_ptr, vector_0_64x2, vector_mult_32x2);
      const uint8* tensor_ptr = &(tensor_data->value) + i;
      const std::array<int32x4_t, 2> tensor_output_value =
          Requantize8x8To32Neon(tensor_ptr, tensor_0_64x2, tensor_mult_32x2);

      const int32x4_t result_low_32x4 =
          vaddq_s32(vector_output_value[0], tensor_output_value[0]);
      const int32x4_t result_high_32x4 =
          vaddq_s32(vector_output_value[1], tensor_output_value[1]);
      int32* output_ptr = &(output->value) + i;
      vst1q_s32(output_ptr + 0, result_low_32x4);
      vst1q_s32(output_ptr + 4, result_high_32x4);
    }
    for (; vector_i < vector_num_elements; ++vector_i, ++i) {
      const int64 vector_value = static_cast<int64>(vector_data[vector_i]);
      int64 vector_in_output_range_64 =
          vector_0_int64 + (vector_value * vector_mult_int32);
      vector_in_output_range_64 =
          std::max(vector_in_output_range_64, lowest_quantized);
      vector_in_output_range_64 =
          std::min(vector_in_output_range_64, highest_quantized);
      const int32 vector_in_output_range =
          static_cast<int32>(vector_in_output_range_64);

      const int64 tensor_value = static_cast<int64>(tensor_data[i]);
      int64 tensor_in_output_range_64 =
          tensor_0_int64 + (tensor_value * tensor_mult_int32);
      tensor_in_output_range_64 =
          std::max(tensor_in_output_range_64, lowest_quantized);
      tensor_in_output_range_64 =
          std::min(tensor_in_output_range_64, highest_quantized);
      const int32 tensor_in_output_range =
          static_cast<int32>(tensor_in_output_range_64);

      output[i] = vector_in_output_range + tensor_in_output_range;
    }
  }
}

#else  // QUANTIZED_ADD_USE_NEON

template <>
void VectorTensorAddition(const quint8* vector_data, float min_vector,
                          float max_vector, int64 vector_num_elements,
                          const quint8* tensor_data, float min_tensor,
                          float max_tensor, int64 tensor_num_elements,
                          float output_min, float output_max, qint32* output) {
  const float vector_0_float =
      QuantizedToFloat<quint8>(0, min_vector, max_vector);
  const float vector_1_float =
      QuantizedToFloat<quint8>(1, min_vector, max_vector);
  const int64 vector_0_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max);
  const int64 vector_1_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max);
  const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64;

  const float tensor_0_float =
      QuantizedToFloat<quint8>(0, min_tensor, max_tensor);
  const float tensor_1_float =
      QuantizedToFloat<quint8>(1, min_tensor, max_tensor);
  const int64 tensor_0_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max);
  const int64 tensor_1_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max);
  const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  for (int64 i = 0; i < tensor_num_elements; ++i) {
    const int64 vector_i = i % vector_num_elements;
    const int64 vector_value = static_cast<int64>(vector_data[vector_i]);
    int64 vector_in_output_range_64 =
        vector_0_int64 + (vector_value * vector_mult_int32);
    vector_in_output_range_64 =
        std::max(vector_in_output_range_64, lowest_quantized);
    vector_in_output_range_64 =
        std::min(vector_in_output_range_64, highest_quantized);
    const int32 vector_in_output_range =
        static_cast<int32>(vector_in_output_range_64);

    const int64 tensor_value = static_cast<int64>(tensor_data[i]);
    int64 tensor_in_output_range_64 =
        tensor_0_int64 + (tensor_value * tensor_mult_int32);
    tensor_in_output_range_64 =
        std::max(tensor_in_output_range_64, lowest_quantized);
    tensor_in_output_range_64 =
        std::min(tensor_in_output_range_64, highest_quantized);
    const int32 tensor_in_output_range =
        static_cast<int32>(tensor_in_output_range_64);

    output[i] = vector_in_output_range + tensor_in_output_range;
  }
}

#endif  // QUANTIZED_ADD_USE_NEON

}  // namespace

template <class T, class Toutput>
class QuantizedAddOp : public OpKernel {
 public:
  explicit QuantizedAddOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
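    // Inputs 0 and 1 are the quantized tensors; inputs 2 through 5 are scalar
    // float tensors carrying the real-number ranges that the quantized values
    // represent.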
    const Tensor& x = context->input(0);
    const Tensor& y = context->input(1);
    const float min_x = context->input(2).flat<float>()(0);
    const float max_x = context->input(3).flat<float>()(0);
    const float min_y = context->input(4).flat<float>()(0);
    const float max_y = context->input(5).flat<float>()(0);

    BCast bcast(BCast::FromShape(x.shape()), BCast::FromShape(y.shape()));
    if (!bcast.IsValid()) {
      context->SetStatus(errors::InvalidArgument(
          "Incompatible shapes: ", x.shape().DebugString(), " vs. ",
          y.shape().DebugString()));
      return;
    }
    Tensor* z;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, BCast::ToShape(bcast.output_shape()), &z));

    // Make sure that we have valid quantization ranges for the input buffers.
    // If the difference between min and max is zero or negative, the range is
    // degenerate and there's no meaningful way to interpret the quantized
    // values.
    OP_REQUIRES(context, (max_x > min_x),
                errors::InvalidArgument("max_x must be larger than min_x."));
    OP_REQUIRES(context, (max_y > min_y),
                errors::InvalidArgument("max_y must be larger than min_y."));
    const T* x_data = x.flat<T>().data();
    const T* y_data = y.flat<T>().data();
    Toutput* z_data = z->flat<Toutput>().data();

    // We want the range of the output to be symmetrical around zero so that
    // adding zero leaves the result unchanged, and to contain the largest of
    // the two input values with some room to spare.
    const float smallest_min = std::min(min_x, min_y);
    const float largest_max = std::max(max_x, max_y);
    const float biggest_range =
        std::max(std::abs(smallest_min), std::abs(largest_max));
    const float output_range = (biggest_range * (1 << 14));
    const float min_z_value = -output_range;
    const float max_z_value = output_range;
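    //
    // For example (illustrative numbers only): with min_x = -1.0f, max_x =
    // 1.0f, min_y = -2.0f, and max_y = 3.0f, biggest_range is 3.0f, so the
    // output range becomes [-49152.0f, 49152.0f] (3.0f * 2^14). That 2^14
    // headroom factor keeps the sum of two requantized inputs well inside
    // the output range.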

    const int ndims = bcast.x_reshape().size();
    if (ndims <= 1) {
      // Rank 0 or 1: either one operand is a single scalar value, or both are
      // equal-length vectors.
      if (x.NumElements() == 1) {
        ScalarAddition<T, Toutput>(context, y_data, min_y, max_y,
                                   y.NumElements(), x_data[0], min_x, max_x,
                                   min_z_value, max_z_value, z_data);
      } else if (y.NumElements() == 1) {
        ScalarAddition<T, Toutput>(context, x_data, min_x, max_x,
                                   x.NumElements(), y_data[0], min_y, max_y,
                                   min_z_value, max_z_value, z_data);
      } else {
        VectorAddition<T, Toutput>(context, x_data, min_x, max_x, y_data, min_y,
                                   max_y, x.NumElements(), min_z_value,
                                   max_z_value, z_data);
      }
    } else if (ndims == 2) {
      // Rank 2: treat the operand with fewer elements as a vector that is
      // broadcast across the larger tensor.
      const T* vector_data;
      int64 vector_num_elements;
      float vector_min;
      float vector_max;
      const T* tensor_data;
      int64 tensor_num_elements;
      float tensor_min;
      float tensor_max;
      if (x.NumElements() < y.NumElements()) {
        vector_data = x_data;
        vector_num_elements = x.NumElements();
        vector_min = min_x;
        vector_max = max_x;
        tensor_data = y_data;
        tensor_num_elements = y.NumElements();
        tensor_min = min_y;
        tensor_max = max_y;
      } else {
        vector_data = y_data;
        vector_num_elements = y.NumElements();
        vector_min = min_y;
        vector_max = max_y;
        tensor_data = x_data;
        tensor_num_elements = x.NumElements();
        tensor_min = min_x;
        tensor_max = max_x;
      }
      VectorTensorAddition<T, Toutput>(
          vector_data, vector_min, vector_max, vector_num_elements, tensor_data,
          tensor_min, tensor_max, tensor_num_elements, min_z_value, max_z_value,
          z_data);
    } else {
      LOG(INFO) << "ndims=" << ndims;
      LOG(INFO) << "bcast.x_reshape()="
                << TensorShape(bcast.x_reshape()).DebugString();
      LOG(INFO) << "bcast.y_reshape()="
                << TensorShape(bcast.y_reshape()).DebugString();
      LOG(INFO) << "bcast.x_bcast()="
                << TensorShape(bcast.x_bcast()).DebugString();
      LOG(INFO) << "bcast.y_bcast()="
                << TensorShape(bcast.y_bcast()).DebugString();

      context->SetStatus(errors::Unimplemented(
          "Broadcast between ", context->input(0).shape().DebugString(),
          " and ", context->input(1).shape().DebugString(),
          " is not supported yet."));
      return;
    }

    Tensor* z_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &z_min));
    z_min->flat<float>()(0) = min_z_value;

    Tensor* z_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &z_max));
    z_max->flat<float>()(0) = max_z_value;
  }
};

REGISTER_KERNEL_BUILDER(Name("QuantizedAdd")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T1")
                            .TypeConstraint<quint8>("T2")
                            .TypeConstraint<qint32>("Toutput"),
                        QuantizedAddOp<quint8, qint32>);
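
// This registration covers only quint8 inputs producing a qint32 output on
// the CPU; other type combinations would need their own
// REGISTER_KERNEL_BUILDER lines.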

}  // namespace tensorflow