1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/tf2xla/xla_helpers.h"
17#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
18#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
19#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
20#include "tensorflow/core/platform/macros.h"
21
22namespace tensorflow {
23namespace {
24
25class QuantizeAndDequantizeOp : public XlaOpKernel {
26 public:
27  explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx)
28      : XlaOpKernel(ctx) {
29    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
30    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
31    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
32    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
33                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
34                                        " with signed_input_ ", signed_input_));
35  }
36
37  void Compile(XlaOpKernelContext* ctx) override {
38    xla::ComputationDataHandle input = ctx->Input(0);
39    const DataType data_type = ctx->input_type(0);
40
41    // Comments taken from semantics description at
42    // https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/quantize-and-dequantize
43    //
44    // ... we find m such that
45    //
46    // m = max(abs(input_min), abs(input_max)) if range_given is true,
47    // m = max(abs(min_elem(input)),
48    //         abs(max_elem(input))) otherwise.
49    xla::ComputationBuilder* b = ctx->builder();
50    xla::ComputationDataHandle input_min, input_max;
51    if (range_given_) {
52      double input_min_value, input_max_value;
53      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value));
54      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &input_max_value));
55      input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value);
56      input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value);
57    } else {
58      const xla::Computation* fmax = ctx->GetOrCreateMax(data_type);
59      const xla::Computation* fmin = ctx->GetOrCreateMin(data_type);
60      input_min =
61          b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
62      input_max =
63          b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
64    }
65    xla::ComputationDataHandle m = b->Max(b->Abs(input_min), b->Abs(input_max));
66
67    // Next, we choose our fixed-point quantization buckets, [min_fixed,
68    // max_fixed]. If signed_input is true, this is
69    //
70    // [min_fixed, max_fixed ] = [-((1 << (num_bits - 1)) - 1),
71    //                             (1 << (num_bits - 1)) - 1].
72    //
73    // Otherwise, if signed_input is false, the fixed-point range is
74    //
75    // [min_fixed, max_fixed] = [0, (1 << num_bits) - 1].
76    int64 min_fixed, max_fixed;
77    if (signed_input_) {
78      min_fixed = -((1LL << (num_bits_ - 1)) - 1);
79      max_fixed = (1LL << (num_bits_ - 1)) - 1;
80    } else {
81      min_fixed = 0;
82      max_fixed = (1LL << num_bits_) - 1;
83    }
84
85    // From this we compute our scaling factor, s:
86    //
87    // s = (max_fixed - min_fixed) / (2 * m).
88    xla::ComputationDataHandle s =
89        b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed),
90               b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m));
91
92    // Now we can quantize and dequantize the elements of our tensor. An element
93    // e is transformed into e':
94    //
95    // e' = (e * s).round_to_nearest() / s.
96    xla::ComputationDataHandle result = b->Div(b->Round(b->Mul(input, s)), s);
97
98    ctx->SetOutput(0, result);
99  }
100
101  int64 num_bits_;
102  bool signed_input_;
103  bool range_given_;
104};
105
106REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeOp);
107
108}  // namespace
109}  // namespace tensorflow
110