/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/quantization_utils.h"

namespace tensorflow {

namespace {

// A slow but straightforward implementation of batch normalization.
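// It makes two passes over the data: the first observes the floating-point
// range of the results, and the second quantizes the results into that range.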
template <typename T1, typename T2>
void ReferenceBatchNorm(const Tensor& input, const float input_min,
                        const float input_max, const Tensor& mean,
                        float mean_min, float mean_max, const Tensor& var,
                        float var_min, float var_max, const Tensor& beta,
                        float beta_min, float beta_max, const Tensor& gamma,
                        float gamma_min, float gamma_max,
                        float variance_epsilon, bool scale_after_normalization,
                        Tensor* output, float* output_min, float* output_max) {
  auto input_flat = input.flat<T1>();
  auto mean_flat = mean.flat<T1>();
  auto var_flat = var.flat<T1>();
  auto beta_flat = beta.flat<T1>();
  auto gamma_flat = gamma.flat<T1>();
  auto output_flat = output->flat<T2>();

  const int depth = mean.dim_size(0);
  const int row_count = input_flat.size() / depth;

  *output_min = std::numeric_limits<float>::max();
  *output_max = std::numeric_limits<float>::lowest();
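  // Pass 0 only tracks the float range of the results; pass 1 writes them out
  // quantized into that range.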
  for (int pass = 0; pass < 2; ++pass) {
    const bool is_range_pass = (pass == 0);
    for (int row_index = 0; row_index < row_count; ++row_index) {
      for (int channel = 0; channel < depth; ++channel) {
        const int input_index = (row_index * depth) + channel;
        const float input_value =
            QuantizedToFloat(input_flat(input_index), input_min, input_max);
        const float mean_value =
            QuantizedToFloat(mean_flat(channel), mean_min, mean_max);
        const float var_value =
            QuantizedToFloat(var_flat(channel), var_min, var_max);
        const float beta_value =
            QuantizedToFloat(beta_flat(channel), beta_min, beta_max);
        const float gamma_value =
            QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max);
        float output_value;
        if (scale_after_normalization) {
          output_value = (((input_value - mean_value) /
                           sqrtf(var_value + variance_epsilon)) *
                          gamma_value) +
                         beta_value;
        } else {
          output_value = ((input_value - mean_value) /
                          sqrtf(var_value + variance_epsilon)) +
                         beta_value;
        }
        if (is_range_pass) {
          *output_min = std::min(output_value, *output_min);
          *output_max = std::max(output_value, *output_max);
        } else {
          output_flat(input_index) =
              FloatToQuantized<T2>(output_value, *output_min, *output_max);
        }
      }
    }
  }
}

// An implementation of batch normalization that does the main calculations
// using only fixed-point arithmetic. There's a prologue with some
// floating-point calculations, but assuming the weights are constant these
// could be hoisted to an offline process, or baked into the weights.
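// The per-channel normalization
//   output = gamma * (input - mean) / sqrt(var + epsilon) + beta
// is algebraically folded into a single multiply-add,
//   scale  = gamma / sqrt(var + epsilon)   (gamma omitted when
//                                           !scale_after_normalization)
//   offset = beta - (mean * scale)
//   output = (input * scale) + offset
// so that only the scale/offset prologue needs floating point.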
template <typename T1, typename T2>
void FixedPointBatchNorm(const Tensor& input, const float input_min,
                         const float input_max, const Tensor& mean,
                         float mean_min, float mean_max, const Tensor& var,
                         float var_min, float var_max, const Tensor& beta,
                         float beta_min, float beta_max, const Tensor& gamma,
                         float gamma_min, float gamma_max,
                         float variance_epsilon, bool scale_after_normalization,
                         Tensor* output, float* output_min, float* output_max) {
  auto input_flat = input.flat<T1>();
  auto mean_flat = mean.flat<T1>();
  auto var_flat = var.flat<T1>();
  auto beta_flat = beta.flat<T1>();
  auto gamma_flat = gamma.flat<T1>();
  auto output_flat = output->flat<T2>();

  const int depth = mean.dim_size(0);
  const int row_count = input_flat.size() / depth;

  // The range here is chosen so that typical input values fit in without any
  // overflow or loss of precision, covering -1M to +1M with 10 bits of
  // fixed-point precision.
  *output_min = -(1 << 20);
  *output_max = (1 << 20);

  Tensor scale_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto scale_flat = scale_tensor.flat<T2>();
  Tensor offset_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto offset_flat = offset_tensor.flat<T2>();
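  // Precompute the per-channel scale and offset in the quantized output space
  // so that the per-element loop below involves no floating-point work.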
  for (int channel = 0; channel < depth; ++channel) {
    const float mean_value =
        QuantizedToFloat(mean_flat(channel), mean_min, mean_max);
    const float var_value =
        QuantizedToFloat(var_flat(channel), var_min, var_max);
    const float beta_value =
        QuantizedToFloat(beta_flat(channel), beta_min, beta_max);
    const float gamma_value =
        QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max);
    float scale_value;
    if (scale_after_normalization) {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon)) * gamma_value;
    } else {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon));
    }
    const float offset_value = (-mean_value * scale_value) + beta_value;
    scale_flat(channel) =
        FloatToQuantized<T2>(scale_value, *output_min, *output_max);
    offset_flat(channel) =
        FloatToQuantized<T2>(offset_value, *output_min, *output_max);
  }

  const T2 one_in_output_space =
      FloatToQuantized<T2>(1.0f, *output_min, *output_max);
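  // Multiplying two values quantized into the output range scales the product
  // by an extra factor of the quantized representation of 1.0, so dividing by
  // one_in_output_space below brings the result back into the same range.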
  for (int row_index = 0; row_index < row_count; ++row_index) {
    for (int channel = 0; channel < depth; ++channel) {
      const int input_index = (row_index * depth) + channel;
      const T2 input_value =
          RequantizeInNewRange<T1, T2>(input_flat(input_index), input_min,
                                       input_max, *output_min, *output_max);
      const T2 scale_value = scale_flat(channel);
      const T2 offset_value = offset_flat(channel);
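      // The fixed-point equivalent of (input * scale) + offset, with the
      // division rescaling the doubly-scaled product (see above).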
      const T2 output_value =
          ((input_value * scale_value) / one_in_output_space) + offset_value;
      output_flat(input_index) = output_value;
    }
  }
}

}  // namespace

template <typename T1, typename T2>
class QuantizedBatchNormOp : public OpKernel {
 public:
  explicit QuantizedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon_));
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
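    // Each quantized input tensor is followed by two scalar float inputs
    // giving the minimum and maximum of the range it was quantized from.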
    const Tensor& input = context->input(0);
    const float input_min = context->input(1).flat<float>()(0);
    const float input_max = context->input(2).flat<float>()(0);
    const Tensor& mean = context->input(3);
    const float mean_min = context->input(4).flat<float>()(0);
    const float mean_max = context->input(5).flat<float>()(0);
    const Tensor& var = context->input(6);
    const float var_min = context->input(7).flat<float>()(0);
    const float var_max = context->input(8).flat<float>()(0);
    const Tensor& beta = context->input(9);
    const float beta_min = context->input(10).flat<float>()(0);
    const float beta_max = context->input(11).flat<float>()(0);
    const Tensor& gamma = context->input(12);
    const float gamma_min = context->input(13).flat<float>()(0);
    const float gamma_max = context->input(14).flat<float>()(0);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, beta.dims() == 1,
                errors::InvalidArgument("beta must be 1-dimensional",
                                        beta.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    float output_min;
    float output_max;
    FixedPointBatchNorm<T1, T2>(input, input_min, input_max, mean, mean_min,
                                mean_max, var, var_min, var_max, beta, beta_min,
                                beta_max, gamma, gamma_min, gamma_max,
                                variance_epsilon_, scale_after_normalization_,
                                output, &output_min, &output_max);

    Tensor* output_min_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, {}, &output_min_tensor));
    output_min_tensor->flat<float>()(0) = output_min;

    Tensor* output_max_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, {}, &output_max_tensor));
    output_max_tensor->flat<float>()(0) = output_max;
  }

 private:
  float variance_epsilon_;
  bool scale_after_normalization_;
};

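// Registers the CPU kernel, taking quint8 quantized inputs and producing
// qint32 outputs in the wider fixed range set by FixedPointBatchNorm.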
REGISTER_KERNEL_BUILDER(Name("QuantizedBatchNormWithGlobalNormalization")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint32>("out_type"),
                        QuantizedBatchNormOp<quint8, qint32>);

}  // namespace tensorflow