/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/quantize_and_dequantize_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Simulates the precision loss of quantization on a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
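//
// The exact scaling and rounding rules live in
// QuantizeAndDequantizeOneScaleImpl (see the header). As a rough
// illustration only (not the authoritative math): with signed_input = true,
// num_bits = 8, and a given range of [-1, 1], the simulated integer range is
// roughly [-128, 127], so an input of 0.3 round-trips as
// round(0.3 * 127) / 127 = 38 / 127 ~= 0.2992 rather than exactly 0.3.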
template <typename Device, typename T>
class QuantizeAndDequantizeV2Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    Tensor input_min_tensor;
    Tensor input_max_tensor;
    if (range_given_) {
      // The range is provided as inputs; validate it before use.
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      auto min_val = input_min_tensor.scalar<T>()();
      auto max_val = input_max_tensor.scalar<T>()();
      OP_REQUIRES(ctx, min_val <= max_val,
                  errors::InvalidArgument("Invalid range: input_min ", min_val,
                                          " > input_max ", max_val));
    } else {
      // No range given; allocate scalar temporaries that the functor fills
      // with the min/max computed from the input tensor.
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape(), &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape(), &input_max_tensor));
    }

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
    f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_,
      range_given_, &input_min_tensor, &input_max_tensor, output->flat<T>());
  }

 private:
  bool signed_input_;
  int num_bits_;
  bool range_given_;
};

// Simulates the precision loss of quantization on a float tensor, in the same
// two steps as QuantizeAndDequantizeV2Op above. Almost identical to that op,
// except that num_bits is a scalar tensor input rather than an attr.
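//
// Sketch of the op's input signature (see the op registration for the
// authoritative definition):
//   QuantizeAndDequantizeV3(input, input_min, input_max, num_bits)
// which is why num_bits is read from ctx->input(3) below.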
template <typename Device, typename T>
class QuantizeAndDequantizeV3Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV3Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // num_bits is a runtime input here, so it must be validated on every
    // invocation rather than once at construction time.
    const Tensor& num_bits_tensor = ctx->input(3);
    const int num_bits_val = num_bits_tensor.scalar<int32>()();
    OP_REQUIRES(
        ctx, num_bits_val > 0 && num_bits_val < (signed_input_ ? 62 : 63),
        errors::InvalidArgument("num_bits is out of range: ", num_bits_val,
                                " with signed_input_ ", signed_input_));

    Tensor input_min_tensor;
    Tensor input_max_tensor;
    if (range_given_) {
      // The range is provided as inputs; validate it before use.
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      auto min_val = input_min_tensor.scalar<T>()();
      auto max_val = input_max_tensor.scalar<T>()();
      OP_REQUIRES(ctx, min_val <= max_val,
                  errors::InvalidArgument("Invalid range: input_min ", min_val,
                                          " > input_max ", max_val));
    } else {
      // No range given; allocate scalar temporaries that the functor fills
      // with the min/max computed from the input tensor.
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape(), &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape(), &input_max_tensor));
    }

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
    f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_val,
      range_given_, &input_min_tensor, &input_max_tensor, output->flat<T>());
  }

 private:
  bool signed_input_;
  bool range_given_;
};

// DEPRECATED: Use QuantizeAndDequantizeV2Op.
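// Unlike V2/V3, the quantization range here comes from the input_min and
// input_max attrs (floats fixed at graph-construction time) rather than from
// input tensors, which is why they are cast into scalar tensors below.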
template <typename Device, typename T>
class QuantizeAndDequantizeOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_min", &input_min_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_max", &input_max_));
    if (range_given_) {
      OP_REQUIRES(
          ctx, input_min_ <= input_max_,
          errors::InvalidArgument("Invalid range: input_min ", input_min_,
                                  " > input_max ", input_max_));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // One global scale.
    Tensor input_min_tensor(DataTypeToEnum<T>::value, TensorShape());
    Tensor input_max_tensor(DataTypeToEnum<T>::value, TensorShape());
    // Initialize the tensors with the values in the attrs.
    input_min_tensor.template scalar<T>()() = static_cast<T>(input_min_);
    input_max_tensor.template scalar<T>()() = static_cast<T>(input_max_);

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> functor;
    functor(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
            num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
            output->flat<T>());
  }

 private:
  bool signed_input_;
  int num_bits_;
  bool range_given_;
  float input_min_;
  float input_max_;
};

// Specialization for CPUDevice.
namespace functor {
template <typename T>
struct QuantizeAndDequantizeOneScaleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstVec input,
                  const bool signed_input, const int num_bits,
                  const bool range_given, Tensor* input_min_tensor,
                  Tensor* input_max_tensor, typename TTypes<T>::Vec out) {
    QuantizeAndDequantizeOneScaleImpl<CPUDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, out);
  }
};
}  // namespace functor

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<CPUDevice, T>);
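// TF_CALL_float / TF_CALL_double (from register_types.h) expand the macro for
// the given type, e.g. TF_CALL_float(REGISTER_CPU_KERNEL) becomes
// REGISTER_CPU_KERNEL(float), registering all three kernels for T = float.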
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_max")                         \
                              .HostMemory("input_min")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_max")                         \
                              .HostMemory("input_min")                         \
                              .HostMemory("num_bits")                          \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA
}  // namespace tensorflow
248