1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/core/kernels/sparse_matmul_op.h"

#include <cstdint>
#include <cstring>

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
26
27namespace tensorflow {
// File-wide random generator state, constructed with the fixed arguments
// (1, 1). NOTE(review): presumably a fixed seed so that benchmark inputs are
// reproducible from run to run -- confirm against PhiloxRandom's constructors.
random::PhiloxRandom philox(1, 1);
// Wrapper used by Sparsify() to draw random integers. It is handed the address
// of `philox`, so it must stay declared after it.
random::SimplePhilox rnd(&philox);
// Makes Eigen's operator== visible inside namespace tensorflow.
// NOTE(review): presumably needed for the `flat(i) == T(0)` comparison in
// Sparsify() when T is bfloat16 -- confirm.
using Eigen::operator==;
31
32template <typename T>
33void Sparsify(Tensor* t, float sparsity) {
34  const int64 N = t->NumElements();
35  CHECK_LE(sparsity, 1);
36  auto flat = t->flat<T>();
37  if (sparsity == 1) {
38    flat.setZero();
39    return;
40  }
41  static const uint32 K = 10000;
42  for (int64 i = 0; i < N; ++i) {
43    if (rnd.Uniform(K) < sparsity * K) {
44      flat(i) = T(0);
45    } else if (flat(i) == T(0)) {
46      flat(i) = T(1);
47    }
48  }
49}
50
51Node* SparseMatMulNode(Graph* g, Node* in0, Node* in1, bool transpose_a,
52                       bool transpose_b, bool a_sparse, bool b_sparse) {
53  Node* ret;
54  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "SparseMatMul")
55                  .Input(in0)
56                  .Input(in1)
57                  .Attr("transpose_a", transpose_a)
58                  .Attr("transpose_b", transpose_b)
59                  .Attr("a_is_sparse", a_sparse)
60                  .Attr("b_is_sparse", b_sparse)
61                  .Finalize(g, &ret));
62  return ret;
63}
64
65template <typename TA, typename TB>
66static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d,
67                                 float sparsity_a, float sparsity_b,
68                                 bool transpose_a, bool transpose_b) {
69  bool a_sparse = (sparsity_a > 0);
70  bool b_sparse = (sparsity_b > 0);
71
72  auto left_shape = transpose_a ? TensorShape({d, m}) : TensorShape({m, d});
73  Tensor left(DataTypeToEnum<TA>::value, left_shape);
74  left.flat<TA>().setRandom();
75  Sparsify<TA>(&left, sparsity_a);
76
77  auto right_shape = transpose_b ? TensorShape({n, d}) : TensorShape({d, n});
78  Tensor right(DataTypeToEnum<TB>::value, right_shape);
79  right.flat<TB>().setRandom();
80  Sparsify<TB>(&right, sparsity_b);
81
82  SparseMatMulNode(g, test::graph::Constant(g, left),
83                   test::graph::Constant(g, right), transpose_a, transpose_b,
84                   a_sparse, b_sparse);
85  return g;
86}
87
88template <typename TA, typename TB>
89static Graph* SparseMatMul(int m, int n, int d, float sparsity_a,
90                           float sparsity_b, bool transpose_a,
91                           bool transpose_b) {
92  Graph* g = new Graph(OpRegistry::Global());
93  return SparseMatMulHelper<TA, TB>(g, m, n, d, sparsity_a, sparsity_b,
94                                    transpose_a, transpose_b);
95}
96
97static Graph* ReplicatedSparseMatMul(int m, int n, int d, float sparsity_1,
98                                     float sparsity_2, int copies) {
99  Graph* g = new Graph(OpRegistry::Global());
100  for (int i = 0; i < copies; ++i) {
101    SparseMatMulHelper<float, float>(g, m, n, d, sparsity_1, sparsity_2, false,
102                                     false);
103  }
104  return g;
105}
106
// Defines and registers one benchmark of a single SparseMatMul on the "cpu"
// device.
//   M, K, N:  the product is (M x K) * (K x N).
//   S1, S2:   sparsity of the left / right operand, in percent.
//   TRA, TRB: whether the left / right operand is stored transposed.
//   TA, TB:   element types of the left / right operand.
// All arguments are pasted into the benchmark function name, so they must be
// single tokens. Graph construction happens with timing stopped.
#define BM_SPARSE(M, K, N, S1, S2, TRA, TRB, TA, TB)                           \
  static void                                                                  \
      BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB( \
          int iters) {                                                         \
    testing::StopTiming();                                                     \
    const double sparsity_a = S1 / 100.0;                                      \
    const double sparsity_b = S2 / 100.0;                                      \
    testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2);        \
    testing::SetLabel(                                                         \
        strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", TRA, TRB, \
                        sparsity_a, sparsity_b));                              \
    testing::UseRealTime();                                                    \
    Graph* const graph =                                                       \
        SparseMatMul<TA, TB>(M, N, K, sparsity_a, sparsity_b, TRA, TRB);       \
    testing::StartTiming();                                                    \
    test::Benchmark("cpu", graph).Run(iters);                                  \
  }                                                                            \
  BENCHMARK(                                                                   \
      BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB);
123
// Defines and registers one benchmark that runs `Copies` independent
// float x float SparseMatMul computations, each (M x K) * (K x N), in a single
// graph on the "cpu" device. S1 / S2 are the sparsities of the left / right
// operands in percent. All arguments are pasted into the benchmark function
// name, so they must be single tokens.
#define BM_SPARSE_REPLICATED(M, K, N, S1, S2, Copies)                          \
  static void BM_Sparse_replicated##_##M##_##K##_##N##_##S1##_##S2##_##Copies( \
      int iters) {                                                             \
    testing::StopTiming();                                                     \
    const double sparsity_a = S1 / 100.0;                                      \
    const double sparsity_b = S2 / 100.0;                                      \
    testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * Copies *   \
                            2);                                                \
    testing::SetLabel(strings::Printf("copies: %d sp_a: %0.2f sp_b: %0.2f",    \
                                      (Copies), sparsity_a, sparsity_b));      \
    testing::UseRealTime();                                                    \
    Graph* const graph =                                                       \
        ReplicatedSparseMatMul(M, N, K, sparsity_a, sparsity_b, (Copies));     \
    testing::StartTiming();                                                    \
    test::Benchmark("cpu", graph).Run(iters);                                  \
  }                                                                            \
  BENCHMARK(BM_Sparse_replicated##_##M##_##K##_##N##_##S1##_##S2##_##Copies);
140
// Convenience wrappers around BM_SPARSE that fix the element types of the two
// operands (the last two BM_SPARSE arguments) and forward everything else.

// float x float.
#define BM_SPARSE_FLOAT(M, K, N, S1, S2, TRA, TRB) \
  BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, float)
// bfloat16 x bfloat16.
#define BM_SPARSE_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \
  BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, bfloat16)
// float (left operand) x bfloat16 (right operand).
#define BM_SPARSE_FLOAT_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \
  BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, bfloat16)
// bfloat16 (left operand) x float (right operand).
#define BM_SPARSE_BFLOAT16_FLOAT(M, K, N, S1, S2, TRA, TRB) \
  BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, float)
149
// Argument order: M, K, N, S1, S2, TRA, TRB. S1 is the sparsity (in percent)
// of the left operand `a`, S2 that of the right operand `b`.

// Test sparse a: vary S1 while b stays dense (S2 == 0). The first entry is the
// fully dense baseline.
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 1, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 50, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 99, 0, false, false);
// Test sparse b: vary S2 while a stays dense (S1 == 0).
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 50, false, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, false);
// Test transposing: every transpose combination, first with a sparse, then
// with b sparse.
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, true);
BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, true);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, false);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, true);
BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, true);
166
// Test smaller and non-square sizes.
// Argument order for BM_SPARSE_FLOAT: M, K, N, S1, S2, TRA, TRB.
BM_SPARSE_FLOAT(1024, 1024, 1024, 0, 0, false, false);
BM_SPARSE_FLOAT(1024, 1024, 1024, 1, 0, false, false);
BM_SPARSE_FLOAT(1024, 1024, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(256, 256, 256, 1, 0, false, false);
BM_SPARSE_FLOAT(512, 512, 512, 1, 0, false, false);
BM_SPARSE_FLOAT(2560, 400, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(2560, 400, 1024, 85, 0, true, false);

// M == 400 with a range of K / N sizes, down to a single output column.
BM_SPARSE_FLOAT(400, 800, 2560, 85, 0, false, false);
BM_SPARSE_FLOAT(400, 2560, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(400, 1024, 256, 85, 0, false, false);
BM_SPARSE_FLOAT(400, 256, 1, 85, 0, false, false);

// The same four shapes, replicated 6 times within one graph.
// Argument order for BM_SPARSE_REPLICATED: M, K, N, S1, S2, Copies.
BM_SPARSE_REPLICATED(400, 800, 2560, 85, 0, 6);
BM_SPARSE_REPLICATED(400, 2560, 1024, 85, 0, 6);
BM_SPARSE_REPLICATED(400, 1024, 256, 85, 0, 6);
BM_SPARSE_REPLICATED(400, 256, 1, 85, 0, 6);

// M == 2048 with shrinking K / N.
BM_SPARSE_FLOAT(2048, 1792, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 1024, 768, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 768, 512, 85, 0, false, false);
BM_SPARSE_FLOAT(2048, 512, 256, 85, 0, false, false);

// Same shapes with M == 2049, i.e. one row more than the power of two above.
BM_SPARSE_FLOAT(2049, 1792, 1024, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 1024, 768, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 768, 512, 85, 0, false, false);
BM_SPARSE_FLOAT(2049, 512, 256, 85, 0, false, false);

// The M == 2048 shapes, replicated 6 times within one graph.
BM_SPARSE_REPLICATED(2048, 1792, 1024, 85, 0, 6);
BM_SPARSE_REPLICATED(2048, 1024, 768, 85, 0, 6);
BM_SPARSE_REPLICATED(2048, 768, 512, 85, 0, 6);
BM_SPARSE_REPLICATED(2048, 512, 256, 85, 0, 6);
200
// Test bfloat16: both operands bfloat16 first, then the two mixed
// bfloat16 / float combinations. In every case only the left operand is
// sparsified (S2 == 0).
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 0, 0, false, false);
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 1, 0, false, false);
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_BFLOAT16(2048, 2048, 2048, 99, 0, false, false);
BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 99, 0, false, false);
BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 85, 0, false, false);
BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 99, 0, false, false);
210
211static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_1,
212                                float sparsity_2, int copies) {
213  Graph* g = new Graph(OpRegistry::Global());
214  for (int i = 0; i < copies; ++i) {
215    SparseMatMulHelper<float, float>(g, d, n, m, sparsity_1, sparsity_2, true,
216                                     false);
217    SparseMatMulHelper<float, float>(g, m, d, n, sparsity_2, 0, false, true);
218  }
219  return g;
220}
221
// Defines and registers one benchmark that runs the graph produced by
// MultiSparseMatMul(M, N, K, S1 / 100, S2 / 100, Copies) on the "cpu" device.
// Each copy contributes two matmuls, hence the extra factor of 2 in the
// processed-items count. All arguments are pasted into the benchmark function
// name, so they must be single tokens.
#define BM_SPARSE_MULTI(M, K, N, S1, S2, Copies)                             \
  static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies(    \
      int iters) {                                                           \
    testing::StopTiming();                                                   \
    const double sparsity_1 = S1 / 100.0;                                    \
    const double sparsity_2 = S2 / 100.0;                                    \
    testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2 * 2 *  \
                            Copies);                                         \
    testing::SetLabel(strings::Printf("%d_%d_%d_%d_%0.2f_%0.2f", M, K, N,    \
                                      Copies, sparsity_1, sparsity_2));      \
    testing::UseRealTime();                                                  \
    Graph* const graph =                                                     \
        MultiSparseMatMul(M, N, K, sparsity_1, sparsity_2, Copies);          \
    testing::StartTiming();                                                  \
    test::Benchmark("cpu", graph).Run(iters);                                \
  }                                                                          \
  BENCHMARK(BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2##_##Copies);
237
// Argument order for BM_SPARSE_MULTI: M, K, N, S1, S2, Copies.

// Single copy, assorted shapes.
BM_SPARSE_MULTI(1024, 2140, 4096, 0, 82, 1);
BM_SPARSE_MULTI(1024, 4096, 2048, 83, 83, 1);
BM_SPARSE_MULTI(400, 800, 2560, 85, 85, 1);
BM_SPARSE_MULTI(400, 2560, 1024, 85, 85, 1);
BM_SPARSE_MULTI(400, 1024, 256, 85, 85, 1);
BM_SPARSE_MULTI(400, 256, 1, 85, 85, 1);

// M == 2048 shapes with 1 copy.
BM_SPARSE_MULTI(2048, 1792, 1024, 85, 85, 1);
BM_SPARSE_MULTI(2048, 1024, 768, 85, 85, 1);
BM_SPARSE_MULTI(2048, 768, 512, 85, 85, 1);
BM_SPARSE_MULTI(2048, 512, 256, 85, 85, 1);

// Same shapes with 3 copies.
BM_SPARSE_MULTI(2048, 1792, 1024, 85, 85, 3);
BM_SPARSE_MULTI(2048, 1024, 768, 85, 85, 3);
BM_SPARSE_MULTI(2048, 768, 512, 85, 85, 3);
BM_SPARSE_MULTI(2048, 512, 256, 85, 85, 3);

// Same shapes with 6 copies.
BM_SPARSE_MULTI(2048, 1792, 1024, 85, 85, 6);
BM_SPARSE_MULTI(2048, 1024, 768, 85, 85, 6);
BM_SPARSE_MULTI(2048, 768, 512, 85, 85, 6);
BM_SPARSE_MULTI(2048, 512, 256, 85, 85, 6);
259
260}  // end namespace tensorflow
261
262namespace Eigen {
263namespace internal {
264
265class SparseMatmulOpTest : public ::testing::Test {
266 protected:
267  SparseMatmulOpTest()
268      : PacketSize(Eigen::internal::packet_traits<float>::size) {
269    typedef typename NumTraits<float>::Real RealFloat;
270
271    for (int i = 0; i < kMaxPacketSize; ++i) {
272      data1[i] = internal::random<float>() / RealFloat(PacketSize);
273      data2[i] = internal::random<float>() / RealFloat(PacketSize);
274      data3[i] = internal::random<float>() / RealFloat(PacketSize);
275    }
276    for (int i = kMaxPacketSize; i < kMaxPacketSize * 2; ++i) {
277      data3[i] = internal::random<float>() / RealFloat(PacketSize);
278    }
279
280    // zero out lower 16-bits of mantissa of data3 values
281    // copy bfloat representation to data3_bfloat16
282    for (int i = 0; i < kMaxPacketSize * 2; ++i) {
283      uint16_t* data3_p = reinterpret_cast<uint16_t*>(&data3[i]);
284      uint16_t* data3_bfloat16_p =
285          reinterpret_cast<uint16_t*>(data3_bfloat16) + i;
286#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
287      data3_p[1] = 0;
288      data3_bfloat16_p[0] = data3_p[0];
289#else
290      data3_p[0] = 0;
291      data3_bfloat16_p[0] = data3_p[1];
292#endif
293    }
294  }
295
296  bool areApprox(const float* a, const float* b, int size) {
297    for (int i = 0; i < size; ++i) {
298      if (a[i] != b[i] && !internal::isApprox(a[i], b[i])) {
299        auto ma = Map<const Matrix<float, 1, Dynamic> >(a, size);
300        auto mb = Map<const Matrix<float, 1, Dynamic> >(b, size);
301        std::cout << "[" << ma << "]"
302                  << " != [" << mb << "], differences: [" << (mb - ma) << "]\n";
303        return false;
304      }
305    }
306    return true;
307  }
308
309#ifdef EIGEN_VECTORIZE_AVX512
310  static const int kMaxPacketSize = 16;
311#elif defined EIGEN_VECTORIZE_AVX || defined EIGEN_VECTORIZE_AVX2
312  static const int kMaxPacketSize = 8;
313#else
314  static const int kMaxPacketSize = 4;
315#endif
316  typedef typename Eigen::internal::packet_traits<float>::type Packet;
317  const int PacketSize;
318  // float values
319  EIGEN_ALIGN_MAX float data1[kMaxPacketSize];
320  // output of intrinsics
321  EIGEN_ALIGN_MAX float data2[kMaxPacketSize];
322  // float values with only 7 mantissa bits (bfloat representable)
323  EIGEN_ALIGN_MAX float data3[kMaxPacketSize * 2];
324  // bfloat16 representation of data3
325  EIGEN_ALIGN_MAX float data3_bfloat16[kMaxPacketSize];
326  EIGEN_ALIGN_MAX float ref[kMaxPacketSize];
327};
328
329TEST_F(SparseMatmulOpTest, BroadcastPacketTest) {
330  for (int i = 0; i < PacketSize; ++i) ref[i] = data1[0];
331  internal::pstoreu(data2, internal::pbroadcast_first<Packet>(
332                               internal::ploadu<Packet>(data1)));
333  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
334  if (PacketSize > 1) {
335    for (int i = 0; i < PacketSize; ++i) ref[i] = data1[1];
336    internal::pstoreu(data2, internal::pbroadcast_second<Packet>(
337                                 internal::ploadu<Packet>(data1)));
338    ASSERT_TRUE(areApprox(ref, data2, PacketSize));
339
340    if (PacketSize > 2) {
341      for (int i = 0; i < PacketSize; ++i) ref[i] = data1[2];
342      internal::pstoreu(data2, internal::pbroadcast_third<Packet>(
343                                   internal::ploadu<Packet>(data1)));
344      ASSERT_TRUE(areApprox(ref, data2, PacketSize));
345
346      if (PacketSize > 3) {
347        for (int i = 0; i < PacketSize; ++i) ref[i] = data1[3];
348        internal::pstoreu(data2, internal::pbroadcast_fourth<Packet>(
349                                     internal::ploadu<Packet>(data1)));
350        ASSERT_TRUE(areApprox(ref, data2, PacketSize));
351      }
352    }
353  }
354}
355
356TEST_F(SparseMatmulOpTest, InterleavePacketTest) {
357  if (PacketSize == 8) {  // AVX
358    for (int i = 0; i < PacketSize / 4; ++i) ref[i] = data1[i];
359    for (int i = PacketSize / 4; i < PacketSize / 2; ++i)
360      ref[i] = data1[i + PacketSize / 4];
361    for (int i = PacketSize / 2; i < 3 * PacketSize / 4; ++i)
362      ref[i] = data1[i - PacketSize / 4];
363    for (int i = 3 * PacketSize / 4; i < PacketSize; ++i) ref[i] = data1[i];
364  } else {
365    // No interleaving done for smaller packets
366    for (int i = 0; i < PacketSize; ++i) ref[i] = data1[i];
367  }
368
369  internal::pstoreu(data2, internal::pinterleave4x64<Packet>(
370                               internal::ploadu<Packet>(data1)));
371  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
372}
373
374TEST_F(SparseMatmulOpTest, Bfloat16ExpandTest) {
375  if (PacketSize == 8) {  // AVX
376    for (int i = 0; i < PacketSize / 2; ++i) {
377      ref[i] = data3[i];
378    }
379    for (int i = 0; i < PacketSize / 2; ++i) {
380      ref[i + PacketSize / 2] = data3[i + PacketSize];
381    }
382  } else {
383    for (int i = 0; i < PacketSize; ++i) {
384      ref[i] = data3[i];
385    }
386  }
387  internal::pstoreu(data2, internal::pexpand_bf16_l<Packet>(
388                               internal::ploadu<Packet>(data3_bfloat16)));
389  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
390
391  if (PacketSize == 8) {  // AVX
392    for (int i = 0; i < PacketSize / 2; ++i) {
393      ref[i] = data3[i + PacketSize / 2];
394    }
395    for (int i = 0; i < PacketSize / 2; ++i) {
396      ref[i + PacketSize / 2] = data3[i + 3 * PacketSize / 2];
397    }
398  } else {
399    for (int i = 0; i < PacketSize; ++i) {
400      ref[i] = data3[i + PacketSize];
401    }
402  }
403
404  internal::pstoreu(data2, internal::pexpand_bf16_u<Packet>(
405                               internal::ploadu<Packet>(data3_bfloat16)));
406  ASSERT_TRUE(areApprox(ref, data2, PacketSize));
407}
408
409TEST_F(SparseMatmulOpTest, Bfloat16LoadTest) {
410  if (PacketSize >= 4) {
411    for (int i = 0; i < 4; ++i) ref[i] = data3[i];
412    internal::pstoreu(data2, internal::pload4bf16<Packet>(data3_bfloat16));
413    ASSERT_TRUE(areApprox(ref, data2, 4));
414
415    internal::pstoreu(data2, internal::pload2bf16<Packet>(data3_bfloat16));
416    ASSERT_TRUE(areApprox(ref, data2, 2));
417  }
418}
419
420}  // namespace internal
421}  // namespace Eigen
422