quantized_matmul_op.cc revision 9f2fa2ec4a68bb9e88ee20146927f84e4f9fe199
1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
3ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenLicensed under the Apache License, Version 2.0 (the "License");
4ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardenyou may not use this file except in compliance with the License.
5ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenYou may obtain a copy of the License at
6ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
7ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    http://www.apache.org/licenses/LICENSE-2.0
8ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
9ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenUnless required by applicable law or agreed to in writing, software
10ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendistributed under the License is distributed on an "AS IS" BASIS,
11ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenSee the License for the specific language governing permissions and
13ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardenlimitations under the License.
14ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden==============================================================================*/
15ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
16ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// Implements a quantized eight-bit version of the matmul operation.
17ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
18ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "external/gemmlowp/public/gemmlowp.h"
19ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/contrib/quantization/kernels/quantization_utils.h"
20ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/contrib/quantization/kernels/reference_gemm.h"
21ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/core/framework/op_kernel.h"
22ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/core/framework/tensor.h"
23ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/core/lib/core/errors.h"
24ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
25ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardennamespace tensorflow {
26ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
27ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// We have to break this out as a separate function because there are multiple
28ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// combinations of transpose attributes we need to support, and they have to be
29ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// compile-time constants to work with the templates used internally.
30ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardentemplate <bool TransposeA, bool TransposeB, bool TransposeC>
319f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlowervoid GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data,
329f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                      const quint8* b_data, qint32* c_data, int m, int n, int k,
339f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                      int offset_a, int offset_b, int lda, int ldb, int ldc) {
34ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  const uint8* a_data_as_uint8 = &(a_data->value);
35ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  const uint8* b_data_as_uint8 = &(b_data->value);
36ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  int32* c_data_as_int32 = &(c_data->value);
37ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  static const gemmlowp::MapOrder ResultOrder =
38ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
39ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  static const gemmlowp::MapOrder LhsOrder =
40ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
41ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  static const gemmlowp::MapOrder RhsOrder =
42ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
43ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k,
44ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                                        lda);
45ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n,
46ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                                        ldb);
47ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
48ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                                        ldc);
49ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  const std::tuple<> empty_pipeline = {};
509f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower  auto& worker_threads =
519f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower      *(op_context->device()->tensorflow_cpu_worker_threads());
529f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower  TensorflowGemmContext context(worker_threads.num_threads,
539f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                worker_threads.workers);
54ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
55ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                   gemmlowp::DefaultL8R8BitDepthParams>(
56ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
57ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden}
58ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
59ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardentemplate <class T1, class T2, class Toutput>
60ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardenclass QuantizedMatMulOp : public OpKernel {
61ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden public:
62ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  explicit QuantizedMatMulOp(OpKernelConstruction* context)
63ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      : OpKernel(context) {
64ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
65ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));
66ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  }
67ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
68ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  void Compute(OpKernelContext* context) override {
69ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const Tensor& a = context->input(0);
70ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const Tensor& b = context->input(1);
71ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const float min_a = context->input(2).flat<float>()(0);
72ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const float max_a = context->input(3).flat<float>()(0);
73ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const float min_b = context->input(4).flat<float>()(0);
74ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const float max_b = context->input(5).flat<float>()(0);
75ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
76ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // Make sure that we have valid quantization ranges for the input buffers.
77ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // If the difference between the min and max is negative or zero, it makes
78ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // it hard to do meaningful intermediate operations on the values.
79ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, (max_a > min_a),
80ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("max_a must be larger than min_a."));
81ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, (max_b > min_b),
82ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("max_b must be larger than min_b."));
83ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a);
84ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b);
85ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 offset_c = 0;
86ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 mult_c = 1;
87ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 shift_c = 0;
88ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
89ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // Check that the dimensions of the two matrices are valid.
90ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()),
91ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("In[0] is not a matrix"));
92ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()),
93ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("In[1] is not a matrix"));
94ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
95ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    dim_pair[0].first = transpose_a_ ? 0 : 1;
96ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    dim_pair[0].second = transpose_b_ ? 1 : 0;
97ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
98ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context,
99ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
100ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("Matrix size-compatible: In[0]: ",
101ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                        a.shape().DebugString(), ", In[1]: ",
102ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                        b.shape().DebugString()));
103ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
104ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)),
105ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("shift_c must be between 0 and 31, "
106ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                        "inclusive."));
107ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
108ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    int a_dim_remaining = 1 - dim_pair[0].first;
109ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    int b_dim_remaining = 1 - dim_pair[0].second;
110ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    TensorShape out_shape(
111ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
112ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Tensor* c = nullptr;
113ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
114ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    CHECK(c);
115ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
116ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const T1* a_data = a.flat<T1>().data();
117ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const T2* b_data = b.flat<T2>().data();
118ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Toutput* c_data = c->flat<Toutput>().data();
119ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
120ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const bool transpose_c = false;
121ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t m = a.dim_size(a_dim_remaining);
122ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t n = b.dim_size(b_dim_remaining);
123ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t k = a.dim_size(dim_pair[0].first);
124ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t lda = a.dim_size(1);
125ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t ldb = b.dim_size(1);
126ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t ldc = n;
127ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
128ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // The gemmlowp optimized library only works for a particular set of data
129ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // types, so check if we meet those requirements and
130ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // fall back to a slower reference implementation if not.
13119376f7010507f5f690bba2176a429ee3436afebA. Unique TensorFlower    if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
132ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        std::is_same<Toutput, qint32>() && (offset_c == 0) && (mult_c == 1) &&
133ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        (shift_c == 0) && (transpose_c == false)) {
134ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      if (transpose_a_) {
135ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        if (transpose_b_) {
1369f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower          GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
1379f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                              m, n, k, offset_a, offset_b, lda,
1389f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                              ldb, ldc);
139ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        } else {
1409f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower          GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data,
1419f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                               m, n, k, offset_a, offset_b, lda,
1429f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                               ldb, ldc);
143ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        }
144ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      } else {
145ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        if (transpose_b_) {
1469f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower          GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data,
1479f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                               m, n, k, offset_a, offset_b, lda,
1489f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                               ldb, ldc);
149ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        } else {
1509f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower          GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data,
1519f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                                m, n, k, offset_a, offset_b,
1529f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                                lda, ldb, ldc);
153ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        }
154ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      }
155ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    } else {
156ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      ReferenceGemm<T1, T2, Toutput>(
157ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden          transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a,
158ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden          lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc);
159ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    }
160ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
161ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float min_c_value;
162ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float max_c_value;
163ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    QuantizationRangeForMultiplication<T1, T2, Toutput>(
164ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        min_a, max_a, min_b, max_b, &min_c_value, &max_c_value);
165ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Tensor* c_min = nullptr;
166ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min));
167ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    c_min->flat<float>()(0) = min_c_value;
168ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
169ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Tensor* c_max = nullptr;
170ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max));
171ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    c_max->flat<float>()(0) = max_c_value;
172ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  }
173ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
174ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden private:
175ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  bool transpose_a_;
176ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  bool transpose_b_;
177ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden};
178ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
// Registers the CPU kernel for the "QuantizedMatMul" op for the single
// supported type combination: quint8 inputs (T1, T2) and qint32 output.
179ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenREGISTER_KERNEL_BUILDER(Name("QuantizedMatMul")
180ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                            .Device(DEVICE_CPU)
181ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                            .TypeConstraint<quint8>("T1")
182ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                            .TypeConstraint<quint8>("T2")
183ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                            .TypeConstraint<qint32>("Toutput"),
184ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                        QuantizedMatMulOp<quint8, quint8, qint32>);
185ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
186ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden}  // namespace tensorflow
187