/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Benchmarks for the SparseMatMul op over a range of shapes, sparsity levels,
// transpose flags, and element types (float / bfloat16), plus unit tests for
// the Eigen packet primitives the kernel relies on (further below).

#include "tensorflow/core/kernels/sparse_matmul_op.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

// File-scoped PRNG shared by Sparsify(); the fixed (1, 1) seed keeps the
// generated benchmark inputs reproducible from run to run.
random::PhiloxRandom philox(1, 1);
random::SimplePhilox rnd(&philox);
using Eigen::operator==;

// Zeroes out approximately a `sparsity` fraction of t's elements (each element
// is zeroed independently at random) and forces every surviving element to be
// non-zero, so the tensor's effective sparsity matches the requested level.
//   t:        tensor to sparsify in place; element type must be T.
//   sparsity: fraction in [0, 1]; exactly 1 takes a fast path that zeroes the
//             whole tensor.
template <typename T>
void Sparsify(Tensor* t, float sparsity) {
  const int64 N = t->NumElements();
  CHECK_LE(sparsity, 1);
  auto flat = t->flat<T>();
  if (sparsity == 1) {
    flat.setZero();
    return;
  }
  // Uniform draw over [0, K) compared against sparsity * K implements a
  // Bernoulli(sparsity) coin flip with 1/K resolution.
  static const uint32 K = 10000;
  for (int64 i = 0; i < N; ++i) {
    if (rnd.Uniform(K) < sparsity * K) {
      flat(i) = T(0);
    } else if (flat(i) == T(0)) {
      // An element that happened to be exactly zero already would inflate the
      // measured sparsity; bump it to 1 so only the coin flips above yield
      // zeros.
      flat(i) = T(1);
    }
  }
}

// Adds a SparseMatMul node multiplying in0 and in1 to graph g and returns it.
// transpose_a/transpose_b transpose the corresponding operand; a_sparse /
// b_sparse set the attrs that tell the kernel which operands it may treat as
// sparse.
Node* SparseMatMulNode(Graph* g, Node* in0, Node* in1, bool transpose_a,
                       bool transpose_b, bool a_sparse, bool b_sparse) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "SparseMatMul")
                  .Input(in0)
                  .Input(in1)
                  .Attr("transpose_a", transpose_a)
                  .Attr("transpose_b", transpose_b)
                  .Attr("a_is_sparse", a_sparse)
                  .Attr("b_is_sparse", b_sparse)
                  .Finalize(g, &ret));
  return ret;
}

// Builds an (m x d) * (d x n) SparseMatMul into g using freshly generated
// constant inputs with the requested sparsity levels, and returns g.
// transpose_a/transpose_b swap the *stored* shape of the corresponding
// operand so that the multiplication still produces an m x n result.
template <typename TA, typename TB>
static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d,
                                 float sparsity_a, float sparsity_b,
                                 bool transpose_a, bool transpose_b) {
  // Only advertise an operand as sparse when it actually contains zeros.
  bool a_sparse = (sparsity_a > 0);
  bool b_sparse = (sparsity_b > 0);

  auto left_shape = transpose_a ? TensorShape({d, m}) : TensorShape({m, d});
  Tensor left(DataTypeToEnum<TA>::value, left_shape);
  left.flat<TA>().setRandom();
  Sparsify<TA>(&left, sparsity_a);

  auto right_shape = transpose_b ? TensorShape({n, d}) : TensorShape({d, n});
  Tensor right(DataTypeToEnum<TB>::value, right_shape);
  right.flat<TB>().setRandom();
  Sparsify<TB>(&right, sparsity_b);

  SparseMatMulNode(g, test::graph::Constant(g, left),
                   test::graph::Constant(g, right), transpose_a, transpose_b,
                   a_sparse, b_sparse);
  return g;
}

// Convenience wrapper: creates a fresh graph holding a single SparseMatMul.
template <typename TA, typename TB>
static Graph* SparseMatMul(int m, int n, int d, float sparsity_a,
                           float sparsity_b, bool transpose_a,
                           bool transpose_b) {
  Graph* g = new Graph(OpRegistry::Global());
  return SparseMatMulHelper<TA, TB>(g, m, n, d, sparsity_a, sparsity_b,
                                    transpose_a, transpose_b);
}

// Builds a graph with `copies` independent float SparseMatMuls of identical
// shape/sparsity, used to benchmark several instances running side by side.
static Graph* ReplicatedSparseMatMul(int m, int n, int d, float sparsity_1,
                                     float sparsity_2, int copies) {
  Graph* g = new Graph(OpRegistry::Global());
  for (int i = 0; i < copies; ++i) {
    SparseMatMulHelper<float, float>(g, m, n, d, sparsity_1, sparsity_2, false,
                                     false);
  }
  return g;
}

// Benchmarks an (M x K) * (K x N) SparseMatMul with element types TA / TB.
// S1 and S2 are sparsity *percentages* (0-100) applied to the left and right
// operands respectively; TRA / TRB are the transpose flags. ItemsProcessed is
// the 2*M*K*N flop count of a dense matmul of the same shape.
#define BM_SPARSE(M, K, N, S1, S2, TRA, TRB, TA, TB)                          \
  static void                                                                 \
  BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB(    \
      int iters) {                                                            \
    testing::StopTiming();                                                    \
    testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2);       \
    auto label = strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", \
                                 TRA, TRB, S1 / 100.0, S2 / 100.0);           \
    testing::SetLabel(label);                                                 \
    testing::UseRealTime();                                                   \
    auto g = SparseMatMul<TA, TB>(M, N, K, S1 / 100.0, S2 / 100.0, TRA, TRB); \
    testing::StartTiming();                                                   \
    test::Benchmark("cpu", g).Run(iters);                                     \
  }                                                                           \
  BENCHMARK(                                                                  \
      BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB);

// Like BM_SPARSE, but the graph holds `Copies` identical untransposed float
// matmuls (see ReplicatedSparseMatMul); ItemsProcessed scales accordingly.
#define BM_SPARSE_REPLICATED(M, K, N, S1, S2, Copies)                          \
  static void BM_Sparse_replicated##_##M##_##K##_##N##_##S1##_##S2##_##Copies( \
      int iters) {                                                             \
    testing::StopTiming();                                                     \
    testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * Copies *   \
                            2);                                                \
    auto label = strings::Printf("copies: %d sp_a: %0.2f sp_b: %0.2f",         \
                                 (Copies), S1 / 100.0, S2 / 100.0);            \
    testing::SetLabel(label);                                                  \
    testing::UseRealTime();                                                    \
    auto g =                                                                   \
        ReplicatedSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, (Copies));     \
    testing::StartTiming();                                                    \
    test::Benchmark("cpu", g).Run(iters);                                      \
  }                                                                            \
  BENCHMARK(BM_Sparse_replicated##_##M##_##K##_##N##_##S1##_##S2##_##Copies);

// Type-specialized shorthands for BM_SPARSE.
#define BM_SPARSE_FLOAT(M, K, N, S1, S2, TRA, TRB) \
  BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, float)
#define BM_SPARSE_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \
  BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, bfloat16)
#define BM_SPARSE_FLOAT_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \
  BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, bfloat16)
#define BM_SPARSE_BFLOAT16_FLOAT(M, K, N, S1, S2, TRA, TRB) \
  BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, float)

// Test sparse b
// NOTE(review): S1 is forwarded to SparseMatMul's sparsity_a (the left/"a"
// operand), so this group actually varies the sparsity of matrix a, and the
// "Test sparse a" group below varies matrix b. The two section headers appear
// swapped — confirm intent before relying on the labels.
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 1, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 50, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 99, 0, false, false);
// Test sparse a
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 50, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, false);
// Test transposing
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, true);
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, true);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, true);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, true);

// Test smaller sizes
BM_SPARSE_FLOAT(1024, 1024, 1024, 0, 0, false, false);
BM_SPARSE_FLOAT(1024, 1024, 1024, 1, 0, false, false);
BM_SPARSE_FLOAT(1024, 1024, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(256, 256, 256, 1, 0, false, false);
BM_SPARSE_FLOAT(512, 512, 512, 1, 0, false, false);
BM_SPARSE_FLOAT(2560, 400, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(2560, 400, 1024, 85, 0, true, false);

// Rectangular shapes with a small left dimension.
BM_SPARSE_FLOAT(400, 800, 2560, 85, 0, false, false);
BM_SPARSE_FLOAT(400, 2560, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(400, 1024, 256, 85, 0, false, false);
BM_SPARSE_FLOAT(400, 256, 1, 85, 0, false, false);

// Same shapes as above, 6 independent copies per graph
// (see BM_SPARSE_REPLICATED).
BM_SPARSE_REPLICATED(400, 800, 2560, 85, 0, 6);
BM_SPARSE_REPLICATED(400, 2560, 1024, 85, 0, 6);
BM_SPARSE_REPLICATED(400, 1024, 256, 85, 0, 6);
BM_SPARSE_REPLICATED(400, 256, 1, 85, 0, 6);

BM_SPARSE_FLOAT(2048, 1792, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 1024, 768, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 768, 512, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 512, 256, 85, 0, false, false);

// M = 2049 — presumably exercises the non-power-of-two / odd-row-count code
// path in the kernel; TODO confirm.
BM_SPARSE_FLOAT(2049, 1792, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 1024, 768, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 768, 512, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 512, 256, 85, 0, false, false);

BM_SPARSE_REPLICATED(2048, 1792, 1024, 85, 0, 6);
BM_SPARSE_REPLICATED(2048, 1024, 768, 85, 0, 6);
BM_SPARSE_REPLICATED(2048, 768, 512, 85, 0, 6);
BM_SPARSE_REPLICATED(2048, 512, 256, 85, 0, 6);

// Test bfloat16
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 0, 0, false, false);
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 1, 0, false, false);
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 99, 0, false, false);
// Mixed precision: bfloat16 on one operand, float on the other.
BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 99, 0, false, false);
BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 99, 0, false, false);

// Builds `copies` pairs of float SparseMatMuls per graph: each pair is one
// transposed-lhs matmul (d x n output from m-inner-dim operands) followed by
// one transposed-rhs matmul (m x d output). Presumably modelled on a
// forward/backward-style workload — confirm against the kernel's intended
// use before extending.
static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_1,
                                float sparsity_2, int copies) {
  Graph* g = new Graph(OpRegistry::Global());
  for (int i = 0; i < copies; ++i) {
    SparseMatMulHelper<float, float>(g, d, n, m, sparsity_1, sparsity_2, true,
                                     false);
    // Second matmul: only the left operand is sparsified (sparsity_2); the
    // right operand is dense (sparsity 0).
    SparseMatMulHelper<float, float>(g, m, d, n, sparsity_2, 0, false, true);
  }
  return g;
}

// Benchmarks MultiSparseMatMul. ItemsProcessed counts both matmuls of each of
// the `Copies` pairs (hence the extra factor of 2). S1/S2 are sparsity
// percentages, as in BM_SPARSE.
#define BM_SPARSE_MULTI(M, K, N, S1, S2, Copies)                             \
  static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies(    \
      int iters) {                                                           \
    testing::StopTiming();                                                   \
    testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2 * 2 *  \
                            Copies);                                         \
    auto label = strings::Printf("%d_%d_%d_%d_%0.2f_%0.2f", M, K, N, Copies, \
                                 S1 / 100.0, S2 / 100.0);                    \
    testing::SetLabel(label);                                                \
    testing::UseRealTime();                                                  \
    auto g = MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, Copies);     \
    testing::StartTiming();                                                  \
    test::Benchmark("cpu", g).Run(iters);                                    \
  }                                                                          \
  BENCHMARK(BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies);

BM_SPARSE_MULTI(1024, 2140, 4096, 0, 82, 1);
BM_SPARSE_MULTI(1024, 4096, 2048, 83, 83, 1);
BM_SPARSE_MULTI(400, 800, 2560, 85, 85, 1);
BM_SPARSE_MULTI(400, 2560, 1024, 85, 85, 1);
BM_SPARSE_MULTI(400, 1024, 256, 85, 85, 1);
BM_SPARSE_MULTI(400, 256, 1, 85, 85, 1);

BM_SPARSE_MULTI(2048, 1792, 1024, 85, 85, 1);
BM_SPARSE_MULTI(2048, 1024, 768, 85, 85, 1);
BM_SPARSE_MULTI(2048, 768, 512, 85, 85, 1);
BM_SPARSE_MULTI(2048, 512, 256, 85, 85, 1);

// Same shapes, 3 and then 6 pairs per graph.
BM_SPARSE_MULTI(2048, 1792, 1024, 85, 85, 3);
BM_SPARSE_MULTI(2048, 1024, 768, 85, 85, 3);
BM_SPARSE_MULTI(2048, 768, 512, 85, 85, 3);
BM_SPARSE_MULTI(2048, 512, 256, 85, 85, 3);

BM_SPARSE_MULTI(2048, 1792, 1024, 85, 85, 6);
BM_SPARSE_MULTI(2048, 1024, 768, 85, 85, 6);
BM_SPARSE_MULTI(2048, 768, 512, 85, 85, 6);
BM_SPARSE_MULTI(2048, 512, 256, 85, 85, 6);

}  // end namespace tensorflow

namespace Eigen {
namespace internal {

// Fixture for testing the packet-level primitives (broadcast, interleave,
// bfloat16 expand/load) that the SparseMatMul kernel builds on. Fills the
// data arrays with random floats scaled down by the packet size, and derives
// a bfloat16-representable copy of data3.
class SparseMatmulOpTest : public ::testing::Test {
 protected:
  SparseMatmulOpTest()
      : PacketSize(Eigen::internal::packet_traits<float>::size) {
    typedef typename NumTraits<float>::Real RealFloat;

    for (int i = 0; i < kMaxPacketSize; ++i) {
      data1[i] = internal::random<float>() / RealFloat(PacketSize);
      data2[i] = internal::random<float>() / RealFloat(PacketSize);
      data3[i] = internal::random<float>() / RealFloat(PacketSize);
    }
    // data3 is twice as long as the others; fill its upper half too.
    for (int i = kMaxPacketSize; i < kMaxPacketSize * 2; ++i) {
      data3[i] = internal::random<float>() / RealFloat(PacketSize);
    }

    // Zero out the lower 16 bits of each data3 float (making the value
    // exactly representable in bfloat16) and pack the surviving upper 16 bits
    // into data3_bfloat16, which therefore holds 2*kMaxPacketSize bfloat16
    // bit patterns even though it is declared as a float array.
    for (int i = 0; i < kMaxPacketSize * 2; ++i) {
      uint16_t* data3_p = reinterpret_cast<uint16_t*>(&data3[i]);
      uint16_t* data3_bfloat16_p =
          reinterpret_cast<uint16_t*>(data3_bfloat16) + i;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
      // Big endian: the high (bfloat16) half is the first uint16.
      data3_p[1] = 0;
      data3_bfloat16_p[0] = data3_p[0];
#else
      // Little endian: the high (bfloat16) half is the second uint16.
      data3_p[0] = 0;
      data3_bfloat16_p[0] = data3_p[1];
#endif
    }
  }

  // Returns true iff a[i] and b[i] are bitwise equal or approximately equal
  // (per Eigen's isApprox) for all i < size; on the first mismatch it prints
  // both vectors and their difference, then returns false.
  bool areApprox(const float* a, const float* b, int size) {
    for (int i = 0; i < size; ++i) {
      if (a[i] != b[i] && !internal::isApprox(a[i], b[i])) {
        auto ma = Map<const Matrix<float, 1, Dynamic> >(a, size);
        auto mb = Map<const Matrix<float, 1, Dynamic> >(b, size);
        std::cout << "[" << ma << "]"
                  << " != [" << mb << "], differences: [" << (mb - ma)
                  << "]\n";
        return false;
      }
    }
    return true;
  }

  // Widest float packet the build can use; the arrays below are sized for the
  // worst case so the same fixture works at any runtime PacketSize.
#ifdef EIGEN_VECTORIZE_AVX512
  static const int kMaxPacketSize = 16;
#elif defined EIGEN_VECTORIZE_AVX || defined EIGEN_VECTORIZE_AVX2
  static const int kMaxPacketSize = 8;
#else
  static const int kMaxPacketSize = 4;
#endif
  typedef typename Eigen::internal::packet_traits<float>::type Packet;
  const int PacketSize;
  // float values
  EIGEN_ALIGN_MAX float data1[kMaxPacketSize];
  // output of intrinsics
  EIGEN_ALIGN_MAX float data2[kMaxPacketSize];
  // float values with only 7 mantissa bits (bfloat representable)
  EIGEN_ALIGN_MAX float data3[kMaxPacketSize * 2];
  // bfloat16 representation of data3 (2*kMaxPacketSize packed uint16 values
  // stored in float storage)
  EIGEN_ALIGN_MAX float data3_bfloat16[kMaxPacketSize];
  // expected values for the current test
  EIGEN_ALIGN_MAX float ref[kMaxPacketSize];
};

// pbroadcast_{first,second,third,fourth} must replicate lane 0/1/2/3 of the
// input packet across every lane of the result.
TEST_F(SparseMatmulOpTest, BroadcastPacketTest) {
  for (int i = 0; i < PacketSize; ++i) ref[i] = data1[0];
  internal::pstoreu(data2, internal::pbroadcast_first<Packet>(
                               internal::ploadu<Packet>(data1)));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
  if (PacketSize > 1) {
    for (int i = 0; i < PacketSize; ++i) ref[i] = data1[1];
    internal::pstoreu(data2, internal::pbroadcast_second<Packet>(
                                 internal::ploadu<Packet>(data1)));
    ASSERT_TRUE(areApprox(ref, data2, PacketSize));

    if (PacketSize > 2) {
      for (int i = 0; i < PacketSize; ++i) ref[i] = data1[2];
      internal::pstoreu(data2, internal::pbroadcast_third<Packet>(
                                   internal::ploadu<Packet>(data1)));
      ASSERT_TRUE(areApprox(ref, data2, PacketSize));

      if (PacketSize > 3) {
        for (int i = 0; i < PacketSize; ++i) ref[i] = data1[3];
        internal::pstoreu(data2, internal::pbroadcast_fourth<Packet>(
                                     internal::ploadu<Packet>(data1)));
        ASSERT_TRUE(areApprox(ref, data2, PacketSize));
      }
    }
  }
}

// For 8-wide (AVX) packets, pinterleave4x64 must swap the two middle 64-bit
// lanes: [d0 d1 | d2 d3 | d4 d5 | d6 d7] -> [d0 d1 | d4 d5 | d2 d3 | d6 d7].
// For narrower packets it must be the identity.
TEST_F(SparseMatmulOpTest, InterleavePacketTest) {
  if (PacketSize == 8) {  // AVX
    for (int i = 0; i < PacketSize / 4; ++i) ref[i] = data1[i];
    for (int i = PacketSize / 4; i < PacketSize / 2; ++i)
      ref[i] = data1[i + PacketSize / 4];
    for (int i = PacketSize / 2; i < 3 * PacketSize / 4; ++i)
      ref[i] = data1[i - PacketSize / 4];
    for (int i = 3 * PacketSize / 4; i < PacketSize; ++i) ref[i] = data1[i];
  } else {
    // No interleaving done for smaller packets
    for (int i = 0; i < PacketSize; ++i) ref[i] = data1[i];
  }

  internal::pstoreu(data2, internal::pinterleave4x64<Packet>(
                               internal::ploadu<Packet>(data1)));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
}

// pexpand_bf16_l / pexpand_bf16_u must expand the lower / upper bfloat16
// halves of a packet of packed bfloat16 values back to floats. On AVX the
// expected lane order reflects the kernel's interleaved layout (halves drawn
// from data3[0..3]/data3[8..11] and data3[4..7]/data3[12..15] respectively).
TEST_F(SparseMatmulOpTest, Bfloat16ExpandTest) {
  if (PacketSize == 8) {  // AVX
    for (int i = 0; i < PacketSize / 2; ++i) {
      ref[i] = data3[i];
    }
    for (int i = 0; i < PacketSize / 2; ++i) {
      ref[i + PacketSize / 2] = data3[i + PacketSize];
    }
  } else {
    for (int i = 0; i < PacketSize; ++i) {
      ref[i] = data3[i];
    }
  }
  internal::pstoreu(data2, internal::pexpand_bf16_l<Packet>(
                               internal::ploadu<Packet>(data3_bfloat16)));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));

  if (PacketSize == 8) {  // AVX
    for (int i = 0; i < PacketSize / 2; ++i) {
      ref[i] = data3[i + PacketSize / 2];
    }
    for (int i = 0; i < PacketSize / 2; ++i) {
      ref[i + PacketSize / 2] = data3[i + 3 * PacketSize / 2];
    }
  } else {
    for (int i = 0; i < PacketSize; ++i) {
      ref[i] = data3[i + PacketSize];
    }
  }

  internal::pstoreu(data2, internal::pexpand_bf16_u<Packet>(
                               internal::ploadu<Packet>(data3_bfloat16)));
  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
}

// pload4bf16 / pload2bf16 must load the first 4 / 2 packed bfloat16 values
// and expand them to floats; only the leading lanes are checked.
TEST_F(SparseMatmulOpTest, Bfloat16LoadTest) {
  if (PacketSize >= 4) {
    for (int i = 0; i < 4; ++i) ref[i] = data3[i];
    internal::pstoreu(data2, internal::pload4bf16<Packet>(data3_bfloat16));
    ASSERT_TRUE(areApprox(ref, data2, 4));

    internal::pstoreu(data2, internal::pload2bf16<Packet>(data3_bfloat16));
    ASSERT_TRUE(areApprox(ref, data2, 2));
  }
}

}  // namespace internal
}  // namespace Eigen