quantized_instance_norm.cc revision 982549ea3423df4270ff154e5c764beb43d472da
1a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 3a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License"); 4a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudluryou may not use this file except in compliance with the License. 5a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurYou may obtain a copy of the License at 6a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 7a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur http://www.apache.org/licenses/LICENSE-2.0 8a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 9a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurUnless required by applicable law or agreed to in writing, software 10a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS, 11a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurSee the License for the specific language governing permissions and 13a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurlimitations under the License. 14a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur==============================================================================*/ 15a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 16a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#define EIGEN_USE_THREADS 17a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 18a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#if defined(__ARM_NEON__) || defined(__ARM_NEON) 19a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#define USE_NEON 20a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include <arm_neon.h> 21a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#endif 22a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 23a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/framework/numeric_op.h" 25a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/framework/op_kernel.h" 26a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/framework/register_types.h" 27a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/framework/tensor.h" 28a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 29a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#include "tensorflow/core/kernels/quantization_utils.h" 30a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 31a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#ifdef USE_NEON 32a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurnamespace { 33a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 34a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// Single pass mean and variance. 35a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// Shape of `input` is [rows x cols], shape of both `mean` and `variance` 36a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// is [cols]. 37a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// Note, `mean` and `variance` are of 'i' (not scaled). 38a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// The following is a straightforward implementation of the parallel algorithm 39a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// described in 40a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 41a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurvoid ColMeanAndVariance(const uint8_t* input, const uint32_t rows, 42a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint32_t cols, float* mean, float* variance) { 43a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // The implementation operates on for 16 columns at a time. 44a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // Assumes cols % 16 == 0 45a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) { 46a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // Vector registers to track the running sum across the rows. Since there 47a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // are 16 columns, we have 4 32x4 registers. 48a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur uint32x4_t sum[4] = {0}; 49a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 50a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float nA = 0.0f; 51a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // Running average and the second moment. 52a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float32x4_t xA[4] = {0.0f}; 53a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float32x4_t M2A[4] = {0.0f}; 54a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 55a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint8_t* inp_ptr = input + col_offset; 56a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // Go over the rows in chunks of 256. This is so that we can use 16 bit adds 57a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // to do the accumulation. 58a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (uint32_t row = 0; row < rows; row += 256) { 59a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // Running sum and sum of squares for the 256 rows. 60a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur uint32x4_t sub_sum[4] = {0}; 61a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur uint32x4_t sub_sq_sum[4] = {0}; 62a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint32_t limit = std::min(rows, row + 256); 63a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float nB = limit - row; 64a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (uint32_t subrow = row; subrow < limit; ++subrow) { 65a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint8x16_t v = vld1q_u8(inp_ptr); 66a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur inp_ptr += cols; 67a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 68a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint8x8_t v_high = vget_high_u8(v); 69a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint8x8_t v_low = vget_low_u8(v); 70a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 71a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x8_t v_high_u16 = vmovl_u8(v_high); 72a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x8_t v_low_u16 = vmovl_u8(v_low); 73a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 74a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x4_t v_high_high = vget_high_u16(v_high_u16); 75a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x4_t v_high_low = vget_low_u16(v_high_u16); 76a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x4_t v_low_high = vget_high_u16(v_low_u16); 77a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x4_t v_low_low = vget_low_u16(v_low_u16); 78a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 79a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur sub_sum[0] = vaddw_u16(sub_sum[0], v_high_high); 80a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur sub_sum[1] = vaddw_u16(sub_sum[1], v_high_low); 81a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur sub_sum[2] = vaddw_u16(sub_sum[2], v_low_high); 82a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur sub_sum[3] = vaddw_u16(sub_sum[3], v_low_low); 83a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 84a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur sub_sq_sum[0] = vmlal_u16(sub_sq_sum[0], v_high_high, v_high_high); 85a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur sub_sq_sum[1] = vmlal_u16(sub_sq_sum[1], v_high_low, v_high_low); 86a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur sub_sq_sum[2] = vmlal_u16(sub_sq_sum[2], v_low_high, v_low_high); 87a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur sub_sq_sum[3] = vmlal_u16(sub_sq_sum[3], v_low_low, v_low_low); 88a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 89a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 90a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // Update the full running sum and moment from the ones for 256 rows. 91a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (int i = 0; i < 4; ++i) { 92a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur sum[i] = vaddq_u32(sum[i], sub_sum[i]); 93a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float nX = nA + nB; 94a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // xB is the average of up to 256 elements. 95a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t xB = 96a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vmulq_n_f32(vcvtq_f32_u32(sub_sum[i]), 1.0f / nB); 97a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 98a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // delta = xB - xA 99a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t delta = vsubq_f32(xB, xA[i]); 100a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // xA = (nA * xA + nB * xB) / (nA + nB) 101a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur xA[i] = vmulq_n_f32( 102a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vaddq_f32(vmulq_n_f32(xA[i], nA), vmulq_n_f32(xB, nB)), 1.0f / nX); 103a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 104a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t sub_sum_f32 = vcvtq_f32_u32(sub_sum[i]); 105a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t sub_sum_sq = vmulq_f32(sub_sum_f32, sub_sum_f32); 106a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 107a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // M2B = sum(xB^2) - sum(xB)^2/nB 108a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t M2B = vsubq_f32(vcvtq_f32_u32(sub_sq_sum[i]), 109a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vmulq_n_f32(sub_sum_sq, 1.0f / nB)); 110a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t last_term = 111a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vmulq_n_f32(vmulq_f32(delta, delta), nA * nB / nX); 112a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // M2A = oldM2A + M2B + delta^2 * nA*nB/nX 113a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur M2A[i] = vaddq_f32(vaddq_f32(M2A[i], M2B), last_term); 114a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 115a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur nA += limit; 116a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 117a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 118a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur // Write the final mean and variance for the 16 columns. 119a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float inv_rows = 1.0f / static_cast<float>(rows); 120a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1q_f32(mean + col_offset, vmulq_n_f32(vcvtq_f32_u32(sum[3]), inv_rows)); 121a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1q_f32(mean + col_offset + 4, 122a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vmulq_n_f32(vcvtq_f32_u32(sum[2]), inv_rows)); 123a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1q_f32(mean + col_offset + 8, 124a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vmulq_n_f32(vcvtq_f32_u32(sum[1]), inv_rows)); 125a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1q_f32(mean + col_offset + 12, 126a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vmulq_n_f32(vcvtq_f32_u32(sum[0]), inv_rows)); 127a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 128a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1q_f32(variance + col_offset, vmulq_n_f32(M2A[3], inv_rows)); 129a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1q_f32(variance + col_offset + 4, vmulq_n_f32(M2A[2], inv_rows)); 130a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1q_f32(variance + col_offset + 8, vmulq_n_f32(M2A[1], inv_rows)); 131a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1q_f32(variance + col_offset + 12, vmulq_n_f32(M2A[0], inv_rows)); 132a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 133a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur} 134a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 135a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// Compute min and max of (input - mean) / sqrt(variance + epsilon). 136a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// This is done in a separate pass so that the normalized value can be 137a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// temporarily computed in floating point precision and not stored anywhere. 138a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurvoid MinAndMax(const uint8_t* input, const uint32_t rows, const uint32_t cols, 139a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float* mean_ptr, const float* variance_ptr, 140a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float variance_epsilon, float* minimum, float* maximum) { 141a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float v_maximum = std::numeric_limits<float>::min(); 142a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float v_minimum = std::numeric_limits<float>::max(); 143a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t eps = vdupq_n_f32(variance_epsilon); 144a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 145a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) { 146a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset), 147a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(mean_ptr + col_offset + 4), 148a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(mean_ptr + col_offset + 8), 149a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(mean_ptr + col_offset + 12)}; 150a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset), 151a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(variance_ptr + col_offset + 4), 152a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(variance_ptr + col_offset + 8), 153a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(variance_ptr + col_offset + 12)}; 154a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t inv_stddev[4] = { 155a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vrsqrteq_f32(vaddq_f32(variance[0], eps)), 156a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vrsqrteq_f32(vaddq_f32(variance[1], eps)), 157a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vrsqrteq_f32(vaddq_f32(variance[2], eps)), 158a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vrsqrteq_f32(vaddq_f32(variance[3], eps))}; 159a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 160a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint8_t* inp_ptr = input + col_offset; 161a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (uint32_t row = 0; row < rows; ++row) { 162a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint8x16_t v = vld1q_u8(inp_ptr); 163a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur inp_ptr += cols; 164a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 165a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x8_t v_high = vmovl_u8(vget_high_u8(v)); 166a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x8_t v_low = vmovl_u8(vget_low_u8(v)); 167a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 168a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t v_float[4] = { 169a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))), 170a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))), 171a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))), 172a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))}; 173a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 174a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (int i = 0; i < 4; ++i) { 175a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t normed = 176a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]); 177a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x2_t high = vget_high_f32(normed); 178a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x2_t low = vget_low_f32(normed); 179a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float32x2_t tmp_max = vpmax_f32(low, high); 180a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur tmp_max = vpmax_f32(tmp_max, tmp_max); 181a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur v_maximum = std::max(v_maximum, vget_lane_f32(tmp_max, 0)); 182a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float32x2_t tmp_min = vpmin_f32(low, high); 183a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur tmp_min = vpmin_f32(tmp_min, tmp_min); 184a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur v_minimum = std::min(v_minimum, vget_lane_f32(tmp_min, 0)); 185a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 186a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 187a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 188a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur *minimum = v_minimum; 189a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur *maximum = v_maximum; 190a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur} 191a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 192a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// Compute (input - mean) / sqrt(variance + epsilon) in floating point, quantize 193a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur// it in the range (minimum, maximum) and store the result as quint8. 194a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurvoid InstanceNorm(const uint8_t* input, const uint32_t rows, 195a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint32_t cols, const float* mean_ptr, 196a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float* variance_ptr, float variance_epsilon, 197a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float minimum, float maximum, uint8_t* output) { 198a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t eps = vdupq_n_f32(variance_epsilon); 199a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t out_min = vdupq_n_f32(minimum); 200a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float out_scale = 255.0f / (maximum - minimum); 201a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 202a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) { 203a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset + 12), 204a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(mean_ptr + col_offset + 8), 205a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(mean_ptr + col_offset + 4), 206a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(mean_ptr + col_offset)}; 207a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset + 12), 208a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(variance_ptr + col_offset + 8), 209a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(variance_ptr + col_offset + 4), 210a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vld1q_f32(variance_ptr + col_offset)}; 211a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t inv_stddev[4] = { 212a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vrsqrteq_f32(vaddq_f32(variance[0], eps)), 213a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vrsqrteq_f32(vaddq_f32(variance[1], eps)), 214a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vrsqrteq_f32(vaddq_f32(variance[2], eps)), 215a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vrsqrteq_f32(vaddq_f32(variance[3], eps))}; 216a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint8_t* inp_ptr = input + col_offset; 217a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur uint8_t* out_ptr = output + col_offset; 218a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (uint32_t row = 0; row < rows; ++row) { 219a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint8x16_t v = vld1q_u8(inp_ptr); 220a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur inp_ptr += cols; 221a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x8_t v_high = vmovl_u8(vget_high_u8(v)); 222a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const uint16x8_t v_low = vmovl_u8(vget_low_u8(v)); 223a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 224a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t v_float[4] = { 225a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))), 226a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))), 227a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))), 228a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))}; 229a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 230a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur uint16x4_t normed_uint16[4]; 231a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur for (int i = 0; i < 4; ++i) { 232a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const float32x4_t normed = 233a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]); 234a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const int32x4_t normed_int32 = 235a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vcvtq_s32_f32(vmulq_n_f32(vsubq_f32(normed, out_min), out_scale)); 236a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur normed_uint16[i] = vqmovun_s32(normed_int32); 237a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 238a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1_u8(out_ptr, 239a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vqmovn_u16(vcombine_u16(normed_uint16[3], normed_uint16[2]))); 240a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vst1_u8(out_ptr + 8, 241a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur vqmovn_u16(vcombine_u16(normed_uint16[1], normed_uint16[0]))); 242a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur out_ptr += cols; 243a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 244a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 245a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur} 246a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 247a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur} // end namespace 248a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#endif // USE_NEON 249a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 250a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurnamespace tensorflow { 251a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 252a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurtypedef Eigen::ThreadPoolDevice CPUDevice; 253a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 254a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlurclass QuantizedInstanceNorm : public OpKernel { 255a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur public: 256a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur explicit QuantizedInstanceNorm(OpKernelConstruction* context) 257a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur : OpKernel(context) { 258a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur OP_REQUIRES_OK(context, 259a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur context->GetAttr("variance_epsilon", &variance_epsilon_)); 260a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur OP_REQUIRES_OK(context, 261a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur context->GetAttr("min_separation", &min_separation_)); 262a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur OP_REQUIRES_OK( 263a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur context, context->GetAttr("output_range_given", &output_range_given_)); 264a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur if (output_range_given_) { 265a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur OP_REQUIRES_OK(context, context->GetAttr("given_y_min", &given_y_min_)); 266a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur OP_REQUIRES_OK(context, context->GetAttr("given_y_max", &given_y_max_)); 267a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur OP_REQUIRES(context, given_y_min_ < given_y_max_, 268a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur errors::InvalidArgument( 269a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur "given_y_min must be less than given_y_max : ", 270a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur given_y_min_, " >= ", given_y_max_)); 271a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 272a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 273a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 274a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur void Compute(OpKernelContext* context) override { 275a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const Tensor& input = context->input(0); 276a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 277a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float input_min = context->input(1).flat<float>()(0); 278a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float input_max = context->input(2).flat<float>()(0); 279a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float input_scale = (input_max - input_min) / 255.0f; 280a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 281982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, input_min < input_max, 282982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument( 283982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "input_min must be less than input_max : ", input_min, 284982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen " >= ", input_max)); 285a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 286a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur auto input_tensor = input.tensor<quint8, 4>(); 287a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur auto N = input_tensor.dimension(0); 288a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur auto H = input_tensor.dimension(1); 289a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur auto W = input_tensor.dimension(2); 290a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur auto C = input_tensor.dimension(3); 291a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 292a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Tensor* output = nullptr; 293a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur OP_REQUIRES_OK(context, 294a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur context->allocate_output(0, input.shape(), &output)); 295a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 296a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Tensor* output_min = nullptr; 297a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min)); 298a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Tensor* output_max = nullptr; 299a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max)); 300a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 301a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur typedef TTypes<float>::Tensor::Index Index; 302a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 303a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#if defined(EIGEN_HAS_INDEX_LIST) 304a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>> 305a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur reduction_indices; 306a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Eigen::IndexList<Eigen::type2index<1>, Index, Index, Eigen::type2index<1>> 307a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur broadcast_spec; 308a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur broadcast_spec.set(1, H); 309a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur broadcast_spec.set(2, W); 310a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Eigen::IndexList<Index, Eigen::type2index<1>, Eigen::type2index<1>, Index> 311a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur expand_spec; 312a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur expand_spec.set(0, N); 313a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur expand_spec.set(3, C); 314a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#else 315a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const Eigen::array<Index, 2> reduction_indices{1, 2}; 316a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const Eigen::array<Index, 4> broadcast_spec{1, H, W, 1}; 317a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur const Eigen::array<Index, 4> expand_spec{N, 1, 1, C}; 318a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#endif 319a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 320a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Eigen::Tensor<float, 2, Eigen::RowMajor> float_mean(N, C); 321a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Eigen::Tensor<float, 2, Eigen::RowMajor> float_variance(N, C); 322a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 323a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#ifdef USE_NEON 324a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur if (N == 1 && (C % 16 == 0)) { 325a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur VLOG(2) << "Calling optimized"; 326a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur ColMeanAndVariance(reinterpret_cast<const uint8_t*>(input_tensor.data()), 327a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur H * W, C, float_mean.data(), float_variance.data()); 328a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 329a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float minimum = given_y_min_, maximum = given_y_max_; 330a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur if (!output_range_given_) { 331a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur MinAndMax(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W, 332a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur C, float_mean.data(), float_variance.data(), 333a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur variance_epsilon_, &minimum, &maximum); 334a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 335a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 336a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur if (maximum - minimum < min_separation_) { 337a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur maximum = minimum + min_separation_; 338a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 339a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 340a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur InstanceNorm(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W, 341a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur C, float_mean.data(), float_variance.data(), 342a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur variance_epsilon_, minimum, maximum, 343a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur reinterpret_cast<uint8_t*>(output->flat<quint8>().data())); 344a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur output_min->scalar<float>()() = minimum; 345a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur output_max->scalar<float>()() = maximum; 346a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } else // NOLINT(readability/braces) 347a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur#endif 348a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur { 349a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur VLOG(2) << "Calling unoptimized"; 350a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float_mean = input_tensor.cast<float>().reduce( 351a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur reduction_indices, Eigen::internal::MeanReducer<float>()); 352a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 353a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float_variance = 354a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur (input_scale * 355a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur ((input_tensor.cast<float>() - 356a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float_mean.reshape(expand_spec).broadcast(broadcast_spec)))) 357a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur .square() 358a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur .reduce(reduction_indices, Eigen::internal::MeanReducer<float>()); 359a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 360a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Eigen::Tensor<float, 4, Eigen::RowMajor> instance_normed = 361a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur input_scale * 362a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur (input_tensor.cast<float>() - 363a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float_mean.reshape(expand_spec).broadcast(broadcast_spec)) * 364a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur (float_variance + variance_epsilon_) 365a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur .rsqrt() 366a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur .reshape(expand_spec) 367a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur .broadcast(broadcast_spec); 368a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 369a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Eigen::Tensor<float, 0, Eigen::RowMajor> normed_min; 370a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur Eigen::Tensor<float, 0, Eigen::RowMajor> normed_max; 371a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 372a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur if (!output_range_given_) { 373a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur normed_min = instance_normed.minimum(); 374a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur normed_max = instance_normed.maximum(); 375a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } else { 376a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur normed_min() = given_y_min_; 377a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur normed_max() = given_y_max_; 378a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 379a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 380a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur if (normed_max() - normed_min() < min_separation_) { 381a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur normed_max() = normed_min() + min_separation_; 382a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 383a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 384a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur FloatToQuantizedStruct<quint8> output_f2q(normed_min(), normed_max()); 385a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur auto instance_normed_quantized = 386a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur QUANTIZE_WITH_EIGEN(instance_normed, output_f2q, quint8); 387a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 388a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur output->tensor<quint8, 4>().device( 389a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur context->template eigen_device<CPUDevice>()) = 390a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur instance_normed_quantized; 391a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur output_min->flat<float>()(0) = normed_min(); 392a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur output_max->flat<float>()(0) = normed_max(); 393a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 394a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur } 395a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 396a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur private: 397a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float variance_epsilon_; 398a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float min_separation_; 399a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur bool output_range_given_; 400a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float given_y_min_; 401a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur float given_y_max_; 402a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur}; 403a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 404a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath KudlurREGISTER_KERNEL_BUILDER(Name("QuantizedInstanceNorm") 405a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur .Device(DEVICE_CPU) 406a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur .TypeConstraint<quint8>("T"), 407a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur QuantizedInstanceNorm); 408a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur 409a79a7a21358c12d1b9096b357ea5267fc7800775Manjunath Kudlur} // namespace tensorflow 410