1a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
3a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
4a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudluryou may not use this file except in compliance with the License.
5a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurYou may obtain a copy of the License at
6a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
7a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
8a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
9a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurUnless required by applicable law or agreed to in writing, software
10a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
11a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurSee the License for the specific language governing permissions and
13a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurlimitations under the License.
14a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur==============================================================================*/
15a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
16a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#define EIGEN_USE_THREADS
17a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
18a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#if defined(__ARM_NEON__) || defined(__ARM_NEON)
19a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#define USE_NEON
20a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include <arm_neon.h>
21a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#endif
22a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
23a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/framework/numeric_op.h"
25a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h"
26a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/framework/register_types.h"
27a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/framework/tensor.h"
28a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
29a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/kernels/quantization_utils.h"
30a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
31a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#ifdef USE_NEON
32a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurnamespace {
33a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
// Single-pass per-column mean and (population) variance.
//
// `input` is a row-major [rows x cols] matrix of uint8 values. `mean` and
// `variance` each receive `cols` floats, stored in natural column order.
// Both statistics are computed on the raw quantized integer values; no input
// scale is applied here.
//
// Requires cols % 16 == 0 (16 columns are processed per outer iteration).
// NOTE(review): rows == 0 makes the final 1 / rows scale infinite; callers
// are presumably expected to pass rows > 0 - confirm.
//
// Partial results for chunks of rows are merged with the parallel algorithm
// described in
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
void ColMeanAndVariance(const uint8_t* input, const uint32_t rows,
                        const uint32_t cols, float* mean, float* variance) {
  // The implementation operates on 16 columns at a time.
  // Assumes cols % 16 == 0
  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Running integer sum across all rows. Since there are 16 columns, we
    // have 4 32x4 registers. Register k holds columns [12 - 4k, 16 - 4k) of
    // the current 16-column tile (register 3 is the lowest four columns).
    uint32x4_t sum[4] = {0};

    // Number of rows already merged into xA / M2A.
    float nA = 0.0f;
    // Running average and the second central moment of those nA rows.
    float32x4_t xA[4] = {0.0f};
    float32x4_t M2A[4] = {0.0f};

    const uint8_t* inp_ptr = input + col_offset;
    // Go over the rows in chunks of at most 256. With uint8 inputs a chunk's
    // sum is at most 256 * 255 = 65280 and its sum of squares at most
    // 256 * 255^2 = 16646400 < 2^24, so both 32-bit integer accumulators
    // convert to float32 without rounding.
    for (uint32_t row = 0; row < rows; row += 256) {
      // Running sum and sum of squares for this chunk of rows.
      uint32x4_t sub_sum[4] = {0};
      uint32x4_t sub_sq_sum[4] = {0};
      const uint32_t limit = std::min(rows, row + 256);
      // Number of rows in this chunk.
      const float nB = limit - row;
      for (uint32_t subrow = row; subrow < limit; ++subrow) {
        const uint8x16_t v = vld1q_u8(inp_ptr);
        inp_ptr += cols;

        // Widen the 16 bytes to 16-bit lanes, then split into four groups of
        // four columns each.
        const uint8x8_t v_high = vget_high_u8(v);
        const uint8x8_t v_low = vget_low_u8(v);

        const uint16x8_t v_high_u16 = vmovl_u8(v_high);
        const uint16x8_t v_low_u16 = vmovl_u8(v_low);

        const uint16x4_t v_high_high = vget_high_u16(v_high_u16);
        const uint16x4_t v_high_low = vget_low_u16(v_high_u16);
        const uint16x4_t v_low_high = vget_high_u16(v_low_u16);
        const uint16x4_t v_low_low = vget_low_u16(v_low_u16);

        // Widening add into the 32-bit sums.
        sub_sum[0] = vaddw_u16(sub_sum[0], v_high_high);
        sub_sum[1] = vaddw_u16(sub_sum[1], v_high_low);
        sub_sum[2] = vaddw_u16(sub_sum[2], v_low_high);
        sub_sum[3] = vaddw_u16(sub_sum[3], v_low_low);

        // Widening multiply-accumulate: sub_sq_sum += v * v.
        sub_sq_sum[0] = vmlal_u16(sub_sq_sum[0], v_high_high, v_high_high);
        sub_sq_sum[1] = vmlal_u16(sub_sq_sum[1], v_high_low, v_high_low);
        sub_sq_sum[2] = vmlal_u16(sub_sq_sum[2], v_low_high, v_low_high);
        sub_sq_sum[3] = vmlal_u16(sub_sq_sum[3], v_low_low, v_low_low);
      }

      // Scalars shared by all four registers when merging this chunk.
      const float nX = nA + nB;
      const float inv_nB = 1.0f / nB;
      const float inv_nX = 1.0f / nX;
      const float cross_weight = nA * nB / nX;

      // Update the full running sum and moment from the ones for this chunk.
      for (int i = 0; i < 4; ++i) {
        sum[i] = vaddq_u32(sum[i], sub_sum[i]);

        const float32x4_t sub_sum_f32 = vcvtq_f32_u32(sub_sum[i]);
        // xB is the average of the (up to 256) rows in this chunk.
        const float32x4_t xB = vmulq_n_f32(sub_sum_f32, inv_nB);

        // delta = xB - xA
        const float32x4_t delta = vsubq_f32(xB, xA[i]);
        // xA = (nA * xA + nB * xB) / (nA + nB)
        xA[i] = vmulq_n_f32(
            vaddq_f32(vmulq_n_f32(xA[i], nA), vmulq_n_f32(xB, nB)), inv_nX);

        const float32x4_t sub_sum_sq = vmulq_f32(sub_sum_f32, sub_sum_f32);

        // M2B = sum(x^2) - sum(x)^2 / nB over the rows of this chunk.
        const float32x4_t M2B = vsubq_f32(vcvtq_f32_u32(sub_sq_sum[i]),
                                          vmulq_n_f32(sub_sum_sq, inv_nB));
        const float32x4_t last_term =
            vmulq_n_f32(vmulq_f32(delta, delta), cross_weight);
        // M2A = oldM2A + M2B + delta^2 * nA*nB/nX
        M2A[i] = vaddq_f32(vaddq_f32(M2A[i], M2B), last_term);
      }
      // Advance the merged row count by the size of this chunk. (`limit` is
      // the absolute end row, so adding it would over-count from the second
      // chunk onwards and skew the merge weights.)
      nA += nB;
    }

    // Write the final mean and variance for the 16 columns. Register 3 holds
    // the lowest four columns, so registers are stored in reverse order.
    const float inv_rows = 1.0f / static_cast<float>(rows);
    vst1q_f32(mean + col_offset, vmulq_n_f32(vcvtq_f32_u32(sum[3]), inv_rows));
    vst1q_f32(mean + col_offset + 4,
              vmulq_n_f32(vcvtq_f32_u32(sum[2]), inv_rows));
    vst1q_f32(mean + col_offset + 8,
              vmulq_n_f32(vcvtq_f32_u32(sum[1]), inv_rows));
    vst1q_f32(mean + col_offset + 12,
              vmulq_n_f32(vcvtq_f32_u32(sum[0]), inv_rows));

    vst1q_f32(variance + col_offset, vmulq_n_f32(M2A[3], inv_rows));
    vst1q_f32(variance + col_offset + 4, vmulq_n_f32(M2A[2], inv_rows));
    vst1q_f32(variance + col_offset + 8, vmulq_n_f32(M2A[1], inv_rows));
    vst1q_f32(variance + col_offset + 12, vmulq_n_f32(M2A[0], inv_rows));
  }
}
134a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
// Computes the minimum and maximum, over the whole row-major [rows x cols]
// uint8 `input`, of
//   (input - mean) * rsqrt_estimate(variance + variance_epsilon)
// where `mean_ptr` / `variance_ptr` hold one float per column in natural
// column order. The results are written to `*minimum` and `*maximum`.
//
// This is done in a separate pass so that the normalized value can be
// temporarily computed in floating point precision and not stored anywhere.
//
// Requires cols % 16 == 0 (16 columns are processed per outer iteration).
// The reciprocal square root uses vrsqrteq_f32, which is an estimate.
void MinAndMax(const uint8_t* input, const uint32_t rows, const uint32_t cols,
               const float* mean_ptr, const float* variance_ptr,
               float variance_epsilon, float* minimum, float* maximum) {
  // lowest() is the most negative finite float; numeric_limits<float>::min()
  // is the smallest *positive* value and would clamp the maximum above zero.
  float v_maximum = std::numeric_limits<float>::lowest();
  float v_minimum = std::numeric_limits<float>::max();
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Per-column statistics for this tile, in natural column order:
    // register k covers columns [4k, 4k + 4).
    const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset),
                                 vld1q_f32(mean_ptr + col_offset + 4),
                                 vld1q_f32(mean_ptr + col_offset + 8),
                                 vld1q_f32(mean_ptr + col_offset + 12)};
    const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset),
                                     vld1q_f32(variance_ptr + col_offset + 4),
                                     vld1q_f32(variance_ptr + col_offset + 8),
                                     vld1q_f32(variance_ptr + col_offset + 12)};
    const float32x4_t inv_stddev[4] = {
        vrsqrteq_f32(vaddq_f32(variance[0], eps)),
        vrsqrteq_f32(vaddq_f32(variance[1], eps)),
        vrsqrteq_f32(vaddq_f32(variance[2], eps)),
        vrsqrteq_f32(vaddq_f32(variance[3], eps))};

    const uint8_t* inp_ptr = input + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;

      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));

      // Widen to float in natural column order (lowest four columns first)
      // so that v_float[k] lines up with mean[k] / inv_stddev[k] above.
      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high)))};

      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        // Horizontal max / min of the four lanes via two pairwise steps.
        const float32x2_t high = vget_high_f32(normed);
        const float32x2_t low = vget_low_f32(normed);
        float32x2_t tmp_max = vpmax_f32(low, high);
        tmp_max = vpmax_f32(tmp_max, tmp_max);
        v_maximum = std::max(v_maximum, vget_lane_f32(tmp_max, 0));
        float32x2_t tmp_min = vpmin_f32(low, high);
        tmp_min = vpmin_f32(tmp_min, tmp_min);
        v_minimum = std::min(v_minimum, vget_lane_f32(tmp_min, 0));
      }
    }
  }
  *minimum = v_minimum;
  *maximum = v_maximum;
}
191a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
// Normalizes every element of the row-major [rows x cols] uint8 `input` as
//   (input - mean) * rsqrt_estimate(variance + variance_epsilon)
// in floating point, maps the range (minimum, maximum) linearly onto
// [0, 255], and stores the result as uint8 in `output` (same layout as
// `input`). `mean_ptr` / `variance_ptr` hold one float per column.
//
// The scaled value is converted to int32 with vcvtq_s32_f32 and then narrowed
// with saturation, so anything outside (minimum, maximum) clamps to 0 or 255.
// Requires cols % 16 == 0 (16 columns are processed per outer iteration).
void InstanceNorm(const uint8_t* input, const uint32_t rows,
                  const uint32_t cols, const float* mean_ptr,
                  const float* variance_ptr, float variance_epsilon,
                  float minimum, float maximum, uint8_t* output) {
  const float32x4_t epsilon_v = vdupq_n_f32(variance_epsilon);
  const float32x4_t range_lo = vdupq_n_f32(minimum);
  const float quant_scale = 255.0f / (maximum - minimum);

  for (uint32_t tile = 0; tile < cols; tile += 16) {
    // Per-column statistics for this 16-column tile, kept in natural column
    // order: group g covers columns [4g, 4g + 4).
    float32x4_t col_mean[4];
    float32x4_t col_inv_std[4];
    for (int g = 0; g < 4; ++g) {
      col_mean[g] = vld1q_f32(mean_ptr + tile + 4 * g);
      const float32x4_t col_var = vld1q_f32(variance_ptr + tile + 4 * g);
      col_inv_std[g] = vrsqrteq_f32(vaddq_f32(col_var, epsilon_v));
    }

    const uint8_t* src = input + tile;
    uint8_t* dst = output + tile;
    for (uint32_t r = 0; r < rows; ++r, src += cols, dst += cols) {
      // Load 16 bytes and widen them to four float vectors, lowest columns
      // first so they line up with col_mean / col_inv_std.
      const uint8x16_t bytes = vld1q_u8(src);
      const uint16x8_t lower = vmovl_u8(vget_low_u8(bytes));
      const uint16x8_t upper = vmovl_u8(vget_high_u8(bytes));
      const float32x4_t as_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(lower))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(lower))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(upper))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(upper)))};

      uint16x4_t quantized[4];
      for (int g = 0; g < 4; ++g) {
        const float32x4_t centered = vsubq_f32(as_float[g], col_mean[g]);
        const float32x4_t normalized = vmulq_f32(centered, col_inv_std[g]);
        const float32x4_t scaled =
            vmulq_n_f32(vsubq_f32(normalized, range_lo), quant_scale);
        // Signed 32-bit -> unsigned 16-bit with saturation (negatives -> 0).
        quantized[g] = vqmovun_s32(vcvtq_s32_f32(scaled));
      }

      // Saturating narrow to bytes and store the two 8-column halves.
      vst1_u8(dst, vqmovn_u16(vcombine_u16(quantized[0], quantized[1])));
      vst1_u8(dst + 8, vqmovn_u16(vcombine_u16(quantized[2], quantized[3])));
    }
  }
}
246a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
247a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur}  // end namespace
248a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#endif  // USE_NEON
249a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
250a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurnamespace tensorflow {
251a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
252a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurtypedef Eigen::ThreadPoolDevice CPUDevice;
253a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
254a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurclass QuantizedInstanceNorm : public OpKernel {
255a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur public:
256a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur  explicit QuantizedInstanceNorm(OpKernelConstruction* context)
257a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      : OpKernel(context) {
258a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    OP_REQUIRES_OK(context,
259a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                   context->GetAttr("variance_epsilon", &variance_epsilon_));
260a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    OP_REQUIRES_OK(context,
261a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                   context->GetAttr("min_separation", &min_separation_));
262a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    OP_REQUIRES_OK(
263a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        context, context->GetAttr("output_range_given", &output_range_given_));
264a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    if (output_range_given_) {
265a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      OP_REQUIRES_OK(context, context->GetAttr("given_y_min", &given_y_min_));
266a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      OP_REQUIRES_OK(context, context->GetAttr("given_y_max", &given_y_max_));
267a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      OP_REQUIRES(context, given_y_min_ < given_y_max_,
268a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                  errors::InvalidArgument(
269a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                      "given_y_min must be less than given_y_max : ",
270a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                      given_y_min_, " >= ", given_y_max_));
271a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    }
272a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur  }
273a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
274a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur  void Compute(OpKernelContext* context) override {
275a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    const Tensor& input = context->input(0);
276a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
277a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    float input_min = context->input(1).flat<float>()(0);
278a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    float input_max = context->input(2).flat<float>()(0);
279a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    float input_scale = (input_max - input_min) / 255.0f;
280a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
281982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    OP_REQUIRES(context, input_min < input_max,
282982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                errors::InvalidArgument(
283982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                    "input_min must be less than input_max : ", input_min,
284982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                    " >= ", input_max));
285a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
286a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    auto input_tensor = input.tensor<quint8, 4>();
287a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    auto N = input_tensor.dimension(0);
288a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    auto H = input_tensor.dimension(1);
289a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    auto W = input_tensor.dimension(2);
290a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    auto C = input_tensor.dimension(3);
291a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
292a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    Tensor* output = nullptr;
293a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    OP_REQUIRES_OK(context,
294a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                   context->allocate_output(0, input.shape(), &output));
295a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
296a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    Tensor* output_min = nullptr;
297a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
298a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    Tensor* output_max = nullptr;
299a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
300a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
301a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    typedef TTypes<float>::Tensor::Index Index;
302a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
303a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#if defined(EIGEN_HAS_INDEX_LIST)
304a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>>
305a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        reduction_indices;
306a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    Eigen::IndexList<Eigen::type2index<1>, Index, Index, Eigen::type2index<1>>
307a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        broadcast_spec;
308a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    broadcast_spec.set(1, H);
309a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    broadcast_spec.set(2, W);
310a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    Eigen::IndexList<Index, Eigen::type2index<1>, Eigen::type2index<1>, Index>
311a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        expand_spec;
312a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    expand_spec.set(0, N);
313a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    expand_spec.set(3, C);
314a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#else
315a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    const Eigen::array<Index, 2> reduction_indices{1, 2};
316a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    const Eigen::array<Index, 4> broadcast_spec{1, H, W, 1};
317a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    const Eigen::array<Index, 4> expand_spec{N, 1, 1, C};
318a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#endif
319a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
320a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    Eigen::Tensor<float, 2, Eigen::RowMajor> float_mean(N, C);
321a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    Eigen::Tensor<float, 2, Eigen::RowMajor> float_variance(N, C);
322a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
323a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#ifdef USE_NEON
324a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    if (N == 1 && (C % 16 == 0)) {
325a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      VLOG(2) << "Calling optimized";
326a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      ColMeanAndVariance(reinterpret_cast<const uint8_t*>(input_tensor.data()),
327a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                         H * W, C, float_mean.data(), float_variance.data());
328a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
329a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      float minimum = given_y_min_, maximum = given_y_max_;
330a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      if (!output_range_given_) {
331a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        MinAndMax(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
332a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                  C, float_mean.data(), float_variance.data(),
333a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                  variance_epsilon_, &minimum, &maximum);
334a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      }
335a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
336a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      if (maximum - minimum < min_separation_) {
337a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        maximum = minimum + min_separation_;
338a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      }
339a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
340a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      InstanceNorm(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
341a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                   C, float_mean.data(), float_variance.data(),
342a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                   variance_epsilon_, minimum, maximum,
343a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                   reinterpret_cast<uint8_t*>(output->flat<quint8>().data()));
344a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      output_min->scalar<float>()() = minimum;
345a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      output_max->scalar<float>()() = maximum;
346a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    } else  // NOLINT(readability/braces)
347a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#endif
348a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    {
349a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      VLOG(2) << "Calling unoptimized";
350a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      float_mean = input_tensor.cast<float>().reduce(
351a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur          reduction_indices, Eigen::internal::MeanReducer<float>());
352a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
353a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      float_variance =
354a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur          (input_scale *
355a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur           ((input_tensor.cast<float>() -
356a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur             float_mean.reshape(expand_spec).broadcast(broadcast_spec))))
357a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur              .square()
358a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur              .reduce(reduction_indices, Eigen::internal::MeanReducer<float>());
359a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
360a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      Eigen::Tensor<float, 4, Eigen::RowMajor> instance_normed =
361a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur          input_scale *
362a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur          (input_tensor.cast<float>() -
363a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur           float_mean.reshape(expand_spec).broadcast(broadcast_spec)) *
364a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur          (float_variance + variance_epsilon_)
365a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur              .rsqrt()
366a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur              .reshape(expand_spec)
367a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur              .broadcast(broadcast_spec);
368a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
369a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_min;
370a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_max;
371a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
372a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      if (!output_range_given_) {
373a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        normed_min = instance_normed.minimum();
374a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        normed_max = instance_normed.maximum();
375a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      } else {
376a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        normed_min() = given_y_min_;
377a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        normed_max() = given_y_max_;
378a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      }
379a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
380a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      if (normed_max() - normed_min() < min_separation_) {
381a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur        normed_max() = normed_min() + min_separation_;
382a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      }
383a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
384a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      FloatToQuantizedStruct<quint8> output_f2q(normed_min(), normed_max());
385a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      auto instance_normed_quantized =
386a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur          QUANTIZE_WITH_EIGEN(instance_normed, output_f2q, quint8);
387a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
388a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      output->tensor<quint8, 4>().device(
389a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur          context->template eigen_device<CPUDevice>()) =
390a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur          instance_normed_quantized;
391a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      output_min->flat<float>()(0) = normed_min();
392a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur      output_max->flat<float>()(0) = normed_max();
393a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur    }
394a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur  }
395a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
396a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur private:
397a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur  float variance_epsilon_;  // Added to the variance before rsqrt(); also passed to the NEON MinAndMax/InstanceNorm helpers.
398a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur  float min_separation_;  // Minimum output range width: if max - min is smaller, max becomes min + min_separation_.
399a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur  bool output_range_given_;  // True: use given_y_min_/given_y_max_ as the output range; false: compute the range from the data.
400a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur  float given_y_min_;  // Output range minimum used when output_range_given_ is true.
401a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur  float given_y_max_;  // Output range maximum used when output_range_given_ is true (may still be raised by min_separation_).
402a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur};
403a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
404a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurREGISTER_KERNEL_BUILDER(Name("QuantizedInstanceNorm")  // Op name this kernel is registered under.
405a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                            .Device(DEVICE_CPU)  // Registered for the CPU device only.
406a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                            .TypeConstraint<quint8>("T"),  // Applies only when attr "T" is quint8.
407a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur                        QuantizedInstanceNorm);  // Kernel class bound to the registration above.
408a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur
409a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur}  // namespace tensorflow
410