/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/quantization_utils.h"

namespace tensorflow {

namespace {

// A slow but straightforward implementation of batch normalization.
30template <typename T1, typename T2> 31void ReferenceBatchNorm(const Tensor& input, const float input_min, 32 const float input_max, const Tensor& mean, 33 float mean_min, float mean_max, const Tensor& var, 34 float var_min, float var_max, const Tensor& beta, 35 float beta_min, float beta_max, const Tensor& gamma, 36 float gamma_min, float gamma_max, 37 float variance_epsilon, bool scale_after_normalization, 38 Tensor* output, float* output_min, float* output_max) { 39 auto input_flat = input.flat<T1>(); 40 auto mean_flat = mean.flat<T1>(); 41 auto var_flat = var.flat<T1>(); 42 auto beta_flat = beta.flat<T1>(); 43 auto gamma_flat = gamma.flat<T1>(); 44 auto output_flat = output->flat<T2>(); 45 46 const int depth = mean.dim_size(0); 47 const int row_count = input_flat.size() / depth; 48 49 *output_min = std::numeric_limits<float>::max(); 50 *output_max = std::numeric_limits<float>::lowest(); 51 for (int pass = 0; pass < 2; ++pass) { 52 const bool is_range_pass = (pass == 0); 53 for (int row_index = 0; row_index < row_count; ++row_index) { 54 for (int channel = 0; channel < depth; ++channel) { 55 const int input_index = (row_index * depth) + channel; 56 const float input_value = 57 QuantizedToFloat(input_flat(input_index), input_min, input_max); 58 const float mean_value = 59 QuantizedToFloat(mean_flat(channel), mean_min, mean_max); 60 const float var_value = 61 QuantizedToFloat(var_flat(channel), var_min, var_max); 62 const float beta_value = 63 QuantizedToFloat(beta_flat(channel), beta_min, beta_max); 64 const float gamma_value = 65 QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max); 66 float output_value; 67 if (scale_after_normalization) { 68 output_value = (((input_value - mean_value) / 69 sqrtf(var_value + variance_epsilon)) * 70 gamma_value) + 71 beta_value; 72 } else { 73 output_value = ((input_value - mean_value) / 74 sqrtf(var_value + variance_epsilon)) + 75 beta_value; 76 } 77 if (is_range_pass) { 78 *output_min = std::min(output_value, 
*output_min); 79 *output_max = std::max(output_value, *output_max); 80 } else { 81 output_flat(input_index) = 82 FloatToQuantized<T2>(output_value, *output_min, *output_max); 83 } 84 } 85 } 86 } 87} 88 89// An implementation of batch normalization that does the main calculations 90// using only fixed-point arithmetic. There's a prologue with some floating 91// calculations, but assuming the weights are constant these could be hoisted to 92// an offline process, or baked into the weights. 93template <typename T1, typename T2> 94void FixedPointBatchNorm(const Tensor& input, const float input_min, 95 const float input_max, const Tensor& mean, 96 float mean_min, float mean_max, const Tensor& var, 97 float var_min, float var_max, const Tensor& beta, 98 float beta_min, float beta_max, const Tensor& gamma, 99 float gamma_min, float gamma_max, 100 float variance_epsilon, bool scale_after_normalization, 101 Tensor* output, float* output_min, float* output_max) { 102 auto input_flat = input.flat<T1>(); 103 auto mean_flat = mean.flat<T1>(); 104 auto var_flat = var.flat<T1>(); 105 auto beta_flat = beta.flat<T1>(); 106 auto gamma_flat = gamma.flat<T1>(); 107 auto output_flat = output->flat<T2>(); 108 109 const int depth = mean.dim_size(0); 110 const int row_count = input_flat.size() / depth; 111 112 // The range here is chosen so that typical input values fit in without any 113 // overflow or loss of precision, going from +1m to -1m with 10 bits of fixed 114 // point precision. 
115 *output_min = -(1 << 20); 116 *output_max = (1 << 20); 117 118 Tensor scale_tensor(DataTypeToEnum<T2>::v(), {depth}); 119 auto scale_flat = scale_tensor.flat<T2>(); 120 Tensor offset_tensor(DataTypeToEnum<T2>::v(), {depth}); 121 auto offset_flat = offset_tensor.flat<T2>(); 122 for (int channel = 0; channel < depth; ++channel) { 123 const float mean_value = 124 QuantizedToFloat(mean_flat(channel), mean_min, mean_max); 125 const float var_value = 126 QuantizedToFloat(var_flat(channel), var_min, var_max); 127 const float beta_value = 128 QuantizedToFloat(beta_flat(channel), beta_min, beta_max); 129 const float gamma_value = 130 QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max); 131 float scale_value; 132 if (scale_after_normalization) { 133 scale_value = (1.0f / sqrtf(var_value + variance_epsilon)) * gamma_value; 134 } else { 135 scale_value = (1.0f / sqrtf(var_value + variance_epsilon)); 136 } 137 const float offset_value = (-mean_value * scale_value) + beta_value; 138 scale_flat(channel) = 139 FloatToQuantized<T2>(scale_value, *output_min, *output_max); 140 offset_flat(channel) = 141 FloatToQuantized<T2>(offset_value, *output_min, *output_max); 142 } 143 144 const T2 one_in_output_space = 145 FloatToQuantized<T2>(1.0f, *output_min, *output_max); 146 for (int row_index = 0; row_index < row_count; ++row_index) { 147 for (int channel = 0; channel < depth; ++channel) { 148 const int input_index = (row_index * depth) + channel; 149 const T2 input_value = 150 RequantizeInNewRange<T1, T2>(input_flat(input_index), input_min, 151 input_max, *output_min, *output_max); 152 const T2 scale_value = scale_flat(channel); 153 const T2 offset_value = offset_flat(channel); 154 const T2 output_value = 155 ((input_value * scale_value) / one_in_output_space) + offset_value; 156 output_flat(input_index) = output_value; 157 } 158 } 159} 160 161} // namespace 162 163template <typename T1, typename T2> 164class QuantizedBatchNormOp : public OpKernel { 165 public: 166 explicit 
QuantizedBatchNormOp(OpKernelConstruction* context) 167 : OpKernel(context) { 168 OP_REQUIRES_OK(context, 169 context->GetAttr("variance_epsilon", &variance_epsilon_)); 170 OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization", 171 &scale_after_normalization_)); 172 } 173 174 void Compute(OpKernelContext* context) override { 175 const Tensor& input = context->input(0); 176 const float input_min = context->input(1).flat<float>()(0); 177 const float input_max = context->input(2).flat<float>()(0); 178 const Tensor& mean = context->input(3); 179 const float mean_min = context->input(4).flat<float>()(0); 180 const float mean_max = context->input(5).flat<float>()(0); 181 const Tensor& var = context->input(6); 182 const float var_min = context->input(7).flat<float>()(0); 183 const float var_max = context->input(8).flat<float>()(0); 184 const Tensor& beta = context->input(9); 185 const float beta_min = context->input(10).flat<float>()(0); 186 const float beta_max = context->input(11).flat<float>()(0); 187 const Tensor& gamma = context->input(12); 188 const float gamma_min = context->input(13).flat<float>()(0); 189 const float gamma_max = context->input(14).flat<float>()(0); 190 191 OP_REQUIRES(context, input.dims() == 4, 192 errors::InvalidArgument("input must be 4-dimensional", 193 input.shape().DebugString())); 194 OP_REQUIRES(context, mean.dims() == 1, 195 errors::InvalidArgument("mean must be 1-dimensional", 196 mean.shape().DebugString())); 197 OP_REQUIRES(context, var.dims() == 1, 198 errors::InvalidArgument("var must be 1-dimensional", 199 var.shape().DebugString())); 200 OP_REQUIRES(context, beta.dims() == 1, 201 errors::InvalidArgument("beta must be 1-dimensional", 202 beta.shape().DebugString())); 203 OP_REQUIRES(context, gamma.dims() == 1, 204 errors::InvalidArgument("gamma must be 1-dimensional", 205 gamma.shape().DebugString())); 206 207 Tensor* output = nullptr; 208 OP_REQUIRES_OK(context, 209 context->allocate_output(0, input.shape(), 
&output)); 210 float output_min; 211 float output_max; 212 FixedPointBatchNorm<T1, T2>(input, input_min, input_max, mean, mean_min, 213 mean_max, var, var_min, var_max, beta, beta_min, 214 beta_max, gamma, gamma_min, gamma_max, 215 variance_epsilon_, scale_after_normalization_, 216 output, &output_min, &output_max); 217 218 Tensor* output_min_tensor = nullptr; 219 OP_REQUIRES_OK(context, 220 context->allocate_output(1, {}, &output_min_tensor)); 221 output_min_tensor->flat<float>()(0) = output_min; 222 223 Tensor* output_max_tensor = nullptr; 224 OP_REQUIRES_OK(context, 225 context->allocate_output(2, {}, &output_max_tensor)); 226 output_max_tensor->flat<float>()(0) = output_max; 227 } 228 229 private: 230 float variance_epsilon_; 231 bool scale_after_normalization_; 232}; 233 234REGISTER_KERNEL_BUILDER(Name("QuantizedBatchNormWithGlobalNormalization") 235 .Device(DEVICE_CPU) 236 .TypeConstraint<quint8>("Tinput") 237 .TypeConstraint<qint32>("out_type"), 238 QuantizedBatchNormOp<quint8, qint32>); 239 240} // namespace tensorflow 241