1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
3ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenLicensed under the Apache License, Version 2.0 (the "License");
4ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardenyou may not use this file except in compliance with the License.
5ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenYou may obtain a copy of the License at
6ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
7ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    http://www.apache.org/licenses/LICENSE-2.0
8ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
9ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenUnless required by applicable law or agreed to in writing, software
10ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendistributed under the License is distributed on an "AS IS" BASIS,
11ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenSee the License for the specific language governing permissions and
13ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardenlimitations under the License.
14ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden==============================================================================*/
15ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
16ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// Implements a quantized eight-bit version of the matmul operation.
17ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
1816cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower#define EIGEN_USE_THREADS
1916cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower
20c4b09b5df79625a70853fd66b5caa7dd92fb4d1fA. Unique TensorFlower#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
210811b0b6e0bae76489866d7649bdbb7ffdb4e3efKiril Gorovoy#include "public/gemmlowp.h"
223e3633c8b5e2817d502de6dd892c5495cb5e85a3A. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h"
233e3633c8b5e2817d502de6dd892c5495cb5e85a3A. Unique TensorFlower#include "tensorflow/core/framework/tensor.h"
2416cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower#include "tensorflow/core/kernels/meta_support.h"
2516cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower#include "tensorflow/core/kernels/quantization_utils.h"
2616cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower#include "tensorflow/core/kernels/reference_gemm.h"
27ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/core/lib/core/errors.h"
28ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
29ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardennamespace tensorflow {
30ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
31ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// We have to break this out as a separate function because there are multiple
32ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// combinations of transpose attributes we need to support, and they have to be
33ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// compile-time constants to work with the templates used internally.
34ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardentemplate <bool TransposeA, bool TransposeB, bool TransposeC>
359f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlowervoid GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data,
369f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                      const quint8* b_data, qint32* c_data, int m, int n, int k,
379f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                      int offset_a, int offset_b, int lda, int ldb, int ldc) {
38ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  const uint8* a_data_as_uint8 = &(a_data->value);
39ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  const uint8* b_data_as_uint8 = &(b_data->value);
40ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  int32* c_data_as_int32 = &(c_data->value);
41ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  static const gemmlowp::MapOrder ResultOrder =
42ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
43ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  static const gemmlowp::MapOrder LhsOrder =
44ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
45ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  static const gemmlowp::MapOrder RhsOrder =
46ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
47ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k,
48ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                                        lda);
49ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n,
50ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                                        ldb);
51ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
52ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                                        ldc);
53ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  const std::tuple<> empty_pipeline = {};
549f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower  auto& worker_threads =
559f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower      *(op_context->device()->tensorflow_cpu_worker_threads());
569f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower  TensorflowGemmContext context(worker_threads.num_threads,
579f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                worker_threads.workers);
58ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
59ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                   gemmlowp::DefaultL8R8BitDepthParams>(
60ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
611118de02db298159d7df7008df59ffd92801b59fPatrick Nguyen  // Since gemmlowp uses assembly to write to the output, msan won't detect
621118de02db298159d7df7008df59ffd92801b59fPatrick Nguyen  // the output buffer as written to, so we mark it manually.
631118de02db298159d7df7008df59ffd92801b59fPatrick Nguyen  TF_ANNOTATE_MEMORY_IS_INITIALIZED(c_data_as_int32, m * n * sizeof(int32));
64ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden}
65ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
66ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardentemplate <class T1, class T2, class Toutput>
67ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardenclass QuantizedMatMulOp : public OpKernel {
68ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden public:
69ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  explicit QuantizedMatMulOp(OpKernelConstruction* context)
70ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      : OpKernel(context) {
71ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
72ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));
73ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  }
74ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
75ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  void Compute(OpKernelContext* context) override {
76ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const Tensor& a = context->input(0);
77ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const Tensor& b = context->input(1);
78ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const float min_a = context->input(2).flat<float>()(0);
79ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const float max_a = context->input(3).flat<float>()(0);
80ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const float min_b = context->input(4).flat<float>()(0);
81ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const float max_b = context->input(5).flat<float>()(0);
82ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
83ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // Make sure that we have valid quantization ranges for the input buffers.
84ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // If the difference between the min and max is negative or zero, it makes
85ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // it hard to do meaningful intermediate operations on the values.
86ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, (max_a > min_a),
87ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("max_a must be larger than min_a."));
88ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, (max_b > min_b),
89ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("max_b must be larger than min_b."));
90ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a);
91ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b);
92ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 offset_c = 0;
93ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 mult_c = 1;
94ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const int32 shift_c = 0;
95ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
96ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    // Check that the dimensions of the two matrices are valid.
97ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()),
98ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("In[0] is not a matrix"));
99ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()),
100ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("In[1] is not a matrix"));
101ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
102ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    dim_pair[0].first = transpose_a_ ? 0 : 1;
103ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    dim_pair[0].second = transpose_b_ ? 1 : 0;
104ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
105ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context,
106ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
107982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                errors::InvalidArgument(
108982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                    "Matrix size-compatible: In[0]: ", a.shape().DebugString(),
109982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                    ", In[1]: ", b.shape().DebugString()));
110ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
111ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)),
112ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                errors::InvalidArgument("shift_c must be between 0 and 31, "
113ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                        "inclusive."));
114ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
115ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    int a_dim_remaining = 1 - dim_pair[0].first;
116ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    int b_dim_remaining = 1 - dim_pair[0].second;
117ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    TensorShape out_shape(
118ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
119ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Tensor* c = nullptr;
120ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
121ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    CHECK(c);
122ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
123ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const T1* a_data = a.flat<T1>().data();
124ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const T2* b_data = b.flat<T2>().data();
125ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Toutput* c_data = c->flat<Toutput>().data();
126ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
127ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const bool transpose_c = false;
128ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t m = a.dim_size(a_dim_remaining);
129ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t n = b.dim_size(b_dim_remaining);
130ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t k = a.dim_size(dim_pair[0].first);
131ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t lda = a.dim_size(1);
132ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t ldb = b.dim_size(1);
133ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    const size_t ldc = n;
134ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
13516cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower    if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
13616cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower        std::is_same<T2, quint8>() && std::is_same<Toutput, qint32>() &&
13716cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower        (offset_c == 0) && (mult_c == 1) && (shift_c == 0) &&
138bdb2967a298236e24011405907cd19737386934eA. Unique TensorFlower        (transpose_c == false) && (k <= 2048)) {
13916cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower      // Gemmlowp/meta code path works on 32 & 64 bit Arm with NEON Simd and
14016cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower      // allows optimized quantized 8bit to 32bit gemm.
14116cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower      meta::QuantizedGemm(context, transpose_a_, transpose_b_, a_data, b_data,
142f9694e876e56c8e4f46e355e8686d7174fdc3b69A. Unique TensorFlower                          c_data, m, n, k, -offset_a, -offset_b, lda, ldb, ldc);
14316cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower    } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
14416cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower               std::is_same<Toutput, qint32>() && (offset_c == 0) &&
14516cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower               (mult_c == 1) && (shift_c == 0) && (transpose_c == false)) {
14616cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower      // The gemmlowp optimized library only works for a particular set of data
14716cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower      // types, so check if we meet those requirements and fall back to a slower
14816cda320d92cfbfc6870140691ae2c5e6286688cA. Unique TensorFlower      // reference implementation if not.
149ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      if (transpose_a_) {
150ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        if (transpose_b_) {
1519f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower          GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
1529f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                              m, n, k, offset_a, offset_b, lda,
1539f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                              ldb, ldc);
154ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        } else {
1559f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower          GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data,
1569f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                               m, n, k, offset_a, offset_b, lda,
1579f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                               ldb, ldc);
158ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        }
159ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      } else {
160ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        if (transpose_b_) {
1619f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower          GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data,
1629f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                               m, n, k, offset_a, offset_b, lda,
1639f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                               ldb, ldc);
164ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        } else {
1659f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower          GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data,
1669f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                                m, n, k, offset_a, offset_b,
1679f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower                                                lda, ldb, ldc);
168ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        }
169ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      }
170ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    } else {
171ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      ReferenceGemm<T1, T2, Toutput>(
172ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden          transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a,
173ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden          lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc);
174ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    }
175ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
176ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float min_c_value;
177ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float max_c_value;
178ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    QuantizationRangeForMultiplication<T1, T2, Toutput>(
179ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden        min_a, max_a, min_b, max_b, &min_c_value, &max_c_value);
180ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Tensor* c_min = nullptr;
181ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min));
182ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    c_min->flat<float>()(0) = min_c_value;
183ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
184ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Tensor* c_max = nullptr;
185ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max));
186ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    c_max->flat<float>()(0) = max_c_value;
187ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  }
188ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
189ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden private:
190ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  bool transpose_a_;
191ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  bool transpose_b_;
192ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden};
193ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
194ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenREGISTER_KERNEL_BUILDER(Name("QuantizedMatMul")
195ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                            .Device(DEVICE_CPU)
196ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                            .TypeConstraint<quint8>("T1")
197ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                            .TypeConstraint<quint8>("T2")
198ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                            .TypeConstraint<qint32>("Toutput"),
199ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                        QuantizedMatMulOp<quint8, quint8, qint32>);
200ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
201ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden}  // namespace tensorflow
202