// quantized_matmul_op.cc revision 9f2fa2ec4a68bb9e88ee20146927f84e4f9fe199
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized eight-bit version of the matmul operation.
17ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 18ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "external/gemmlowp/public/gemmlowp.h" 19ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/contrib/quantization/kernels/quantization_utils.h" 20ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/contrib/quantization/kernels/reference_gemm.h" 21ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/core/framework/op_kernel.h" 22ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/core/framework/tensor.h" 23ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#include "tensorflow/core/lib/core/errors.h" 24ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 25ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardennamespace tensorflow { 26ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 27ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// We have to break this out as a separate function because there are multiple 28ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// combinations of transpose attributes we need to support, and they have to be 29ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden// compile-time constants to work with the templates used internally. 30ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardentemplate <bool TransposeA, bool TransposeB, bool TransposeC> 319f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlowervoid GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data, 329f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower const quint8* b_data, qint32* c_data, int m, int n, int k, 339f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. 
Unique TensorFlower int offset_a, int offset_b, int lda, int ldb, int ldc) { 34ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const uint8* a_data_as_uint8 = &(a_data->value); 35ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const uint8* b_data_as_uint8 = &(b_data->value); 36ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden int32* c_data_as_int32 = &(c_data->value); 37ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden static const gemmlowp::MapOrder ResultOrder = 38ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor; 39ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden static const gemmlowp::MapOrder LhsOrder = 40ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor; 41ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden static const gemmlowp::MapOrder RhsOrder = 42ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor; 43ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k, 44ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden lda); 45ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n, 46ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden ldb); 47ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n, 48ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden ldc); 49ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const std::tuple<> empty_pipeline = {}; 509f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower auto& worker_threads = 519f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower *(op_context->device()->tensorflow_cpu_worker_threads()); 529f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. 
Unique TensorFlower TensorflowGemmContext context(worker_threads.num_threads, 539f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower worker_threads.workers); 54ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t, 55ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden gemmlowp::DefaultL8R8BitDepthParams>( 56ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline); 57ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden} 58ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 59ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardentemplate <class T1, class T2, class Toutput> 60ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardenclass QuantizedMatMulOp : public OpKernel { 61ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden public: 62ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden explicit QuantizedMatMulOp(OpKernelConstruction* context) 63ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden : OpKernel(context) { 64ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_)); 65ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_)); 66ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } 67ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 68ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden void Compute(OpKernelContext* context) override { 69ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const Tensor& a = context->input(0); 70ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const Tensor& b = context->input(1); 71ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const float min_a = context->input(2).flat<float>()(0); 72ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const float max_a = context->input(3).flat<float>()(0); 73ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const float min_b = 
context->input(4).flat<float>()(0); 74ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const float max_b = context->input(5).flat<float>()(0); 75ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 76ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden // Make sure that we have valid quantization ranges for the input buffers. 77ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden // If the difference between the min and max is negative or zero, it makes 78ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden // it hard to do meaningful intermediate operations on the values. 79ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES(context, (max_a > min_a), 80ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden errors::InvalidArgument("max_a must be larger than min_a.")); 81ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES(context, (max_b > min_b), 82ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden errors::InvalidArgument("max_b must be larger than min_b.")); 83ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const int32 offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a); 84ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const int32 offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b); 85ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const int32 offset_c = 0; 86ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const int32 mult_c = 1; 87ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const int32 shift_c = 0; 88ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 89ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden // Check that the dimensions of the two matrices are valid. 
90ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()), 91ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden errors::InvalidArgument("In[0] is not a matrix")); 92ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()), 93ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden errors::InvalidArgument("In[1] is not a matrix")); 94ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; 95ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden dim_pair[0].first = transpose_a_ ? 0 : 1; 96ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden dim_pair[0].second = transpose_b_ ? 1 : 0; 97ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 98ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES(context, 99ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), 100ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden errors::InvalidArgument("Matrix size-compatible: In[0]: ", 101ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a.shape().DebugString(), ", In[1]: ", 102ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b.shape().DebugString())); 103ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 104ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)), 105ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden errors::InvalidArgument("shift_c must be between 0 and 31, " 106ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden "inclusive.")); 107ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 108ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden int a_dim_remaining = 1 - dim_pair[0].first; 109ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden int b_dim_remaining = 1 - dim_pair[0].second; 110ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden TensorShape out_shape( 
111ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)}); 112ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden Tensor* c = nullptr; 113ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c)); 114ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden CHECK(c); 115ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 116ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const T1* a_data = a.flat<T1>().data(); 117ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const T2* b_data = b.flat<T2>().data(); 118ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden Toutput* c_data = c->flat<Toutput>().data(); 119ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 120ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const bool transpose_c = false; 121ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const size_t m = a.dim_size(a_dim_remaining); 122ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const size_t n = b.dim_size(b_dim_remaining); 123ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const size_t k = a.dim_size(dim_pair[0].first); 124ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const size_t lda = a.dim_size(1); 125ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const size_t ldb = b.dim_size(1); 126ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden const size_t ldc = n; 127ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 128ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden // The gemmlowp optimized library only works for a particular set of data 129ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden // types, so check if we meet those requirements and 130ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden // fall back to a slower reference implementation if not. 13119376f7010507f5f690bba2176a429ee3436afebA. 
Unique TensorFlower if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() && 132ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden std::is_same<Toutput, qint32>() && (offset_c == 0) && (mult_c == 1) && 133ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden (shift_c == 0) && (transpose_c == false)) { 134ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden if (transpose_a_) { 135ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden if (transpose_b_) { 1369f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data, 1379f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower m, n, k, offset_a, offset_b, lda, 1389f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower ldb, ldc); 139ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } else { 1409f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data, 1419f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower m, n, k, offset_a, offset_b, lda, 1429f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower ldb, ldc); 143ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } 144ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } else { 145ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden if (transpose_b_) { 1469f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data, 1479f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower m, n, k, offset_a, offset_b, lda, 1489f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower ldb, ldc); 149ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } else { 1509f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data, 1519f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. Unique TensorFlower m, n, k, offset_a, offset_b, 1529f2fa2ec4a68bb9e88ee20146927f84e4f9fe199A. 
Unique TensorFlower lda, ldb, ldc); 153ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } 154ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } 155ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } else { 156ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden ReferenceGemm<T1, T2, Toutput>( 157ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a, 158ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc); 159ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } 160ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 161ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float min_c_value; 162ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float max_c_value; 163ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden QuantizationRangeForMultiplication<T1, T2, Toutput>( 164ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden min_a, max_a, min_b, max_b, &min_c_value, &max_c_value); 165ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden Tensor* c_min = nullptr; 166ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min)); 167ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden c_min->flat<float>()(0) = min_c_value; 168ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 169ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden Tensor* c_max = nullptr; 170ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max)); 171ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden c_max->flat<float>()(0) = max_c_value; 172ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden } 173ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 174ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden private: 175ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden bool transpose_a_; 176ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden bool transpose_b_; 
177ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden}; 178ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 179ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenREGISTER_KERNEL_BUILDER(Name("QuantizedMatMul") 180ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden .Device(DEVICE_CPU) 181ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden .TypeConstraint<quint8>("T1") 182ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden .TypeConstraint<quint8>("T2") 183ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden .TypeConstraint<qint32>("Toutput"), 184ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden QuantizedMatMulOp<quint8, quint8, qint32>); 185ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 186ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden} // namespace tensorflow 187