1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/tf2xla/xla_helpers.h" 17#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 18#include "tensorflow/compiler/tf2xla/xla_op_registry.h" 19#include "tensorflow/compiler/xla/client/lib/arithmetic.h" 20#include "tensorflow/core/platform/macros.h" 21 22namespace tensorflow { 23namespace { 24 25class QuantizeAndDequantizeOp : public XlaOpKernel { 26 public: 27 explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) 28 : XlaOpKernel(ctx) { 29 OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_)); 30 OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_)); 31 OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_)); 32 OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), 33 errors::InvalidArgument("num_bits is out of range: ", num_bits_, 34 " with signed_input_ ", signed_input_)); 35 } 36 37 void Compile(XlaOpKernelContext* ctx) override { 38 xla::ComputationDataHandle input = ctx->Input(0); 39 const DataType data_type = ctx->input_type(0); 40 41 // Comments taken from semantics description at 42 // https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/quantize-and-dequantize 43 // 44 // ... we find m such that 45 // 46 // m = max(abs(input_min), abs(input_max)) if range_given is true, 47 // m = max(abs(min_elem(input)), 48 // abs(max_elem(input))) otherwise. 49 xla::ComputationBuilder* b = ctx->builder(); 50 xla::ComputationDataHandle input_min, input_max; 51 if (range_given_) { 52 double input_min_value, input_max_value; 53 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value)); 54 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &input_max_value)); 55 input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value); 56 input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value); 57 } else { 58 const xla::Computation* fmax = ctx->GetOrCreateMax(data_type); 59 const xla::Computation* fmin = ctx->GetOrCreateMin(data_type); 60 input_min = 61 b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); 62 input_max = 63 b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); 64 } 65 xla::ComputationDataHandle m = b->Max(b->Abs(input_min), b->Abs(input_max)); 66 67 // Next, we choose our fixed-point quantization buckets, [min_fixed, 68 // max_fixed]. If signed_input is true, this is 69 // 70 // [min_fixed, max_fixed ] = [-((1 << (num_bits - 1)) - 1), 71 // (1 << (num_bits - 1)) - 1]. 72 // 73 // Otherwise, if signed_input is false, the fixed-point range is 74 // 75 // [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]. 76 int64 min_fixed, max_fixed; 77 if (signed_input_) { 78 min_fixed = -((1LL << (num_bits_ - 1)) - 1); 79 max_fixed = (1LL << (num_bits_ - 1)) - 1; 80 } else { 81 min_fixed = 0; 82 max_fixed = (1LL << num_bits_) - 1; 83 } 84 85 // From this we compute our scaling factor, s: 86 // 87 // s = (max_fixed - min_fixed) / (2 * m). 88 xla::ComputationDataHandle s = 89 b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed), 90 b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m)); 91 92 // Now we can quantize and dequantize the elements of our tensor. An element 93 // e is transformed into e': 94 // 95 // e' = (e * s).round_to_nearest() / s. 96 xla::ComputationDataHandle result = b->Div(b->Round(b->Mul(input, s)), s); 97 98 ctx->SetOutput(0, result); 99 } 100 101 int64 num_bits_; 102 bool signed_input_; 103 bool range_given_; 104}; 105 106REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeOp); 107 108} // namespace 109} // namespace tensorflow 110