/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_matmul_op.h"

#include <map>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "include/libxsmm_intrinsics_x86.h"
#include "include/libxsmm_malloc.h"
#include "include/libxsmm_spmdm.h"
#endif

namespace tensorflow {
namespace {

using Eigen::operator==;

template <typename T>
using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;

template <typename T>
using BasicMatrixMap =
    Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;

using Matrix = BasicMatrix<float>;
using MatrixMap = BasicMatrixMap<float>;
using CPUDevice = Eigen::ThreadPoolDevice;
using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;

// Two commonly used static dsizes. We use Eigen::type2index to allow as much
// compile time optimization as possible.
#ifdef EIGEN_HAS_INDEX_LIST
inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
dsizes_00() {
  return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
}
inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
dsizes_10() {
  return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
}
#else
inline DSizes dsizes_00() { return DSizes(0, 0); }
inline DSizes dsizes_10() { return DSizes(1, 0); }
#endif

// Block sizes
// TODO(agarwal): compute these sizes based on cache sizes.
const int K = 64;
const int M = 64;
const int N = 128;

// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
// (num_rows, b), where b = block_size for all but the last block, which may
// have fewer columns.
//
// num_rows and block_size are assumed to be <= 256. This allows storing the
// different indices as uint8.
//
// For each block, we store all the non-zero entries in the data/data3 vectors
// and the corresponding coordinates of each element in the index/index3
// vectors. The index3 vector stores the indices of 3 elements in the same row
// so that these elements can share the same row coordinate. Each entry in
// Index3 corresponds to 3 entries in data3.
//
// Note that the data/indices of all the blocks are stored in the same vectors
// respectively. To identify block boundaries, we store the block offsets using
// index3_offset/index_offset. If there are n blocks in the slice, index3_offset
// and index_offset have n entries each. The indices for the ith block are the
// entries of index3 in the half-open range
// [index3_offset[i-1], index3_offset[i]). Similarly for index/index_offset.
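//
// For example, if the columns of row r within one block (block_size = 8) are
//   [0, 1, 0, 2, 0, 0, 3, 4]
// then the row contributes one Index3 entry {m: r, k1: 1, k2: 3, k3: 6} with
// the values {1, 2, 3} appended to data3, and one Index entry {m: r, k: 7}
// with the leftover value {4} appended to data.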
template <typename T>
struct SparseSlice {
  using ConstMatrixMap = BasicMatrixMap<const T>;

 public:
  // Indices of three elements on the same row.
  struct Index3 {
    uint8 m;  // row
    // columns
    uint8 k1;
    uint8 k2;
    uint8 k3;
  };

  // Index of one element.
  struct Index {
    uint8 m;
    uint8 k;
  };

  SparseSlice(int nrows, int ncols, int bsize)
      : num_rows(nrows), num_cols(ncols), block_size(bsize) {
    DCHECK_LE(nrows, 256);
    DCHECK_LE(block_size, 256);
  }

  // Initializes the slice with data starting at mat(0, col_offset) and with
  // size (num_rows, num_cols).
  // If Transpose is true, implicitly transposes mat.
  template <bool Transpose = false>
  void Initialize(const ConstMatrixMap& mat, int col_offset);

  void Clear();

  // See comments above.
  std::vector<int> index3_offset;
  std::vector<Index3> index3;
  std::vector<T> data3;

  // See comments above. Similar to "index3" except that each element in
  // "index" corresponds to one element in data.
  std::vector<int> index_offset;
  std::vector<Index> index;
  std::vector<T> data;

  // Number of rows and columns for the slice.
  const int num_rows;
  const int num_cols;

  // Block size used to initialize from a matrix.
  const int block_size;
};

template <typename T>
template <bool Transpose>
void SparseSlice<T>::Initialize(
    const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
  const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
  DCHECK_LE(num_rows, mat_rows);
  DCHECK_LE(num_cols + col_offset, mat_cols);

  int num_blocks = (num_cols + block_size - 1) / block_size;
  int mat_size = num_rows * num_cols;

  index3_offset.reserve(num_blocks);
  data3.reserve(mat_size);
  index3.reserve(mat_size / 3);

  index_offset.reserve(num_blocks);
  data.reserve(num_blocks * num_rows * 2);
  index.reserve(num_blocks * num_rows * 2);

  Index3 idx3;
  Index idx;
  int data3_size = 0;
  static const T zero(0);
  for (int i = 0; i < num_blocks; ++i) {
    int num_block_cols = std::min(block_size, num_cols - block_size * i);
    for (int row = 0; row < num_rows; ++row) {
      idx3.m = static_cast<uint8>(row);
      // Safety note: The following code has a race, since it checks whether
      // *curr is nonzero and then reads it again on use.  However, the result
      // of the race is only that some of the "nonzeros" in the resulting sparse
      // representation may actually be zero, which is harmless.
      const auto* start =
          Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
      const auto* curr = start;
      const int stride = Transpose ? mat.dimension(1) : 1;
      const auto* end = start + stride * num_block_cols;
      uint8 k = 0;
#define NEXT_ELEM \
  curr += stride; \
  ++k;
      while (true) {
        while (curr < end && (*curr == zero)) {
          NEXT_ELEM;
        }
        if (curr >= end) break;
        idx3.k1 = k;
        data3.push_back(*curr);
        NEXT_ELEM;

        while (curr < end && (*curr == zero)) {
          NEXT_ELEM;
        }
        if (curr >= end) break;
        idx3.k2 = k;
        data3.push_back(*curr);
        NEXT_ELEM;

        while (curr < end && (*curr == zero)) {
          NEXT_ELEM;
        }
        if (curr >= end) break;
        idx3.k3 = k;
        data3.push_back(*curr);
        NEXT_ELEM;
        index3.push_back(idx3);
#undef NEXT_ELEM
      }
      int num_inserted_mod = data3.size() % 3;
      // Move some elements to index and data if needed.
      data3_size = data3.size() - num_inserted_mod;
      idx.m = idx3.m;
      switch (num_inserted_mod) {
        case 2:
          idx.k = idx3.k2;
          data.push_back(data3[data3_size + 1]);
          index.push_back(idx);
          TF_FALLTHROUGH_INTENDED;
        case 1:
          idx.k = idx3.k1;
          data.push_back(data3[data3_size]);
          index.push_back(idx);
          data3.resize(data3_size);
      }
    }
    col_offset += block_size;
    index3_offset.push_back(index3.size());
    index_offset.push_back(index.size());
  }
  DCHECK_EQ(index3_offset.size(), num_blocks);
  DCHECK_EQ(index_offset.size(), num_blocks);
  DCHECK_EQ(3 * index3.size(), data3.size());
  DCHECK_EQ(index.size(), data.size());
}

template <typename T>
void SparseSlice<T>::Clear() {
  index3_offset.clear();
  index3.clear();
  data3.clear();
  index_offset.clear();
  index.clear();
  data.clear();
}

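// One SIMD packet of floats; kNumOperands is the number of float lanes per
// packet (e.g. 4 with SSE, 8 with AVX).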
using Packet = Eigen::internal::packet_traits<float>::type;
const int kNumOperands = (sizeof(Packet) / sizeof(float));
#define LOAD(x) Eigen::internal::pload<Packet>(x);
#define EXPAND_BFLOAT_L(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
#define EXPAND_BFLOAT_U(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);

#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE

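// Converts one bfloat16 to float by writing its 16 bits into the upper half
// of a zero-initialized float (a bfloat16 is the upper half of an IEEE
// float32). For example, bfloat16 0x3F80 yields float bits 0x3F800000 == 1.0f.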
ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
  float out = 0;
  auto tmp = reinterpret_cast<bfloat16*>(&out);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp[0] = *src;
#else
  tmp[1] = *src;
#endif
  return out;
}

ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload4bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload2bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
  **out += a * **inp;
  ++*inp;
  ++*out;
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
                                float** out) {
  float inp_f = ConvertBfloat16ToFloat(*inp);
  **out += a * inp_f;
  ++*inp;
  ++*out;
}
ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const bfloat16** inp1,
                                    const bfloat16** inp2,
                                    const bfloat16** inp3, float** out) {
  float inp1_f = ConvertBfloat16ToFloat(*inp1);
  float inp2_f = ConvertBfloat16ToFloat(*inp2);
  float inp3_f = ConvertBfloat16ToFloat(*inp3);
  **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const float** inp1,
                                    const float** inp2, const float** inp3,
                                    float** out) {
  **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
  auto tmp = ConvertBfloat16ToFloat(*data);
  *l = Eigen::internal::pset1<Packet>(tmp);
  ++*data;
}

ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
                                  Packet* l2) {
  if (kNumOperands >= 2) {
    auto tmp = ConvertTwoBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *data += 2;
  } else {
    LoadSingleScalar(data, l1);
    LoadSingleScalar(data, l2);
  }
}

ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
                                   Packet* l2, Packet* l3, Packet* l4) {
  if (kNumOperands >= 4) {
    auto tmp = ConvertFourBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
    *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
    *data += 4;
  } else {
    LoadTwoScalars(data, l1, l2);
    LoadTwoScalars(data, l3, l4);
  }
}

ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
  *l = Eigen::internal::pload1<Packet>(*data);
  ++(*data);
}

ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
  LoadSingleScalar(data, l1);
  LoadSingleScalar(data, l2);
}

ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
                                   Packet* l3, Packet* l4) {
  LoadTwoScalars(data, l1, l2);
  LoadTwoScalars(data, l3, l4);
}

template <typename T>
ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
                                    Packet* l3) {
  LoadTwoScalars(data, l1, l2);
  LoadSingleScalar(data, l3);
}

template <typename T>
ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
                                  Packet* l3, Packet* l4, Packet* l5,
                                  Packet* l6) {
  LoadFourScalars(data, l1, l2, l3, l4);
  LoadTwoScalars(data, l5, l6);
}

// Vectorized version of ScalarMulAdd.
ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
  auto inp = reinterpret_cast<const float*>(*binp);
  const auto b = LOAD(inp);
  EXPAND_BFLOAT_L(b, b_0);
  EXPAND_BFLOAT_U(b, b_1);
  *binp += 2 * kNumOperands;
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  FMA(a, b_0, c1, c1);
  FMA(a, b_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way.
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const bfloat16** binp1, const bfloat16** binp2,
                              const bfloat16** binp3, float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  *binp1 += 2 * kNumOperands;
  const auto b2 = LOAD(inp2);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  *binp2 += 2 * kNumOperands;
  const auto b3 = LOAD(inp3);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  *binp3 += 2 * kNumOperands;
  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** binp1,
                                 const bfloat16** binp2, const bfloat16** binp3,
                                 float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  const auto b2 = LOAD(inp2);
  const auto b3 = LOAD(inp3);

  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  auto c3 = LOAD(*out + 2 * kNumOperands);
  auto c4 = LOAD(*out + 3 * kNumOperands);
  const auto b4 = LOAD(inp1 + kNumOperands);
  const auto b5 = LOAD(inp2 + kNumOperands);
  const auto b6 = LOAD(inp3 + kNumOperands);

  EXPAND_BFLOAT_L(b4, b4_0);
  EXPAND_BFLOAT_U(b4, b4_1);
  EXPAND_BFLOAT_L(b5, b5_0);
  EXPAND_BFLOAT_U(b5, b5_1);
  EXPAND_BFLOAT_L(b6, b6_0);
  EXPAND_BFLOAT_U(b6, b6_1);

  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a1, b4_0, c3, c3);
  FMA(a1, b4_1, c4, c4);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a2, b5_0, c3, c3);
  FMA(a2, b5_1, c4, c4);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  FMA(a3, b6_0, c3, c3);
  FMA(a3, b6_1, c4, c4);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  STORE(*out + 2 * kNumOperands, c3);
  STORE(*out + 3 * kNumOperands, c4);
  *out += 4 * kNumOperands;
  *binp1 += 4 * kNumOperands;
  *binp2 += 4 * kNumOperands;
  *binp3 += 4 * kNumOperands;
}

// Apply MulAdd3Way on 128 operands.
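// Each call to TwoMulAdd3Way consumes 4 * kNumOperands output floats, so the
// loop below runs 128 / (8 * kNumOperands) iterations of two calls each
// (e.g. two iterations when kNumOperands == 8).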
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** inp1,
                                 const bfloat16** inp2, const bfloat16** inp3,
                                 float** out) {
  for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  }
}

// Vectorized version of ScalarMulAdd
ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
  const auto b = LOAD(*inp);
  *inp += kNumOperands;
  auto c = LOAD(*out);
  FMA(a, b, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const float** inp1, const float** inp2,
                              const float** inp3, float** out) {
  auto c = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  *inp1 += kNumOperands;
  const auto b2 = LOAD(*inp2);
  *inp2 += kNumOperands;
  const auto b3 = LOAD(*inp3);
  *inp3 += kNumOperands;
  FMA(a1, b1, c, c);
  FMA(a2, b2, c, c);
  FMA(a3, b3, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  auto c1 = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  const auto b2 = LOAD(*inp2);
  const auto b3 = LOAD(*inp3);

  auto c2 = LOAD(*out + kNumOperands);
  const auto b4 = LOAD(*inp1 + kNumOperands);
  const auto b5 = LOAD(*inp2 + kNumOperands);
  const auto b6 = LOAD(*inp3 + kNumOperands);

  FMA(a1, b1, c1, c1);
  FMA(a1, b4, c2, c2);
  FMA(a2, b2, c1, c1);
  FMA(a2, b5, c2, c2);
  FMA(a3, b3, c1, c1);
  FMA(a3, b6, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
  *inp1 += 2 * kNumOperands;
  *inp2 += 2 * kNumOperands;
  *inp3 += 2 * kNumOperands;
}

// Unroll MulAdd3Way for four iterations
ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
                                  const Packet a3, const float** inp1,
                                  const float** inp2, const float** inp3,
                                  float** out) {
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
}

// Apply MulAdd3Way on 128 operands.
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  if (kNumOperands == 8) {
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  } else {
    DCHECK_LE(4 * kNumOperands, 128);
    for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    }
  }
}

// Computes the product of "left_slices" with "num_cols" columns of "right",
// and stores the output in *"output".
// Note that left_slices is a list of SparseSlices, which are conceptually
// assumed to be concatenated along the column dimension. Also, each SparseSlice
// is encoded as a list of blocks with up to N columns. See SparseSlice for more
// details.
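// When Cols != -1 (in particular Cols == N == 128) the column count is a
// compile-time constant and the fully unrolled MulAdd3Way128 path is taken;
// otherwise the generic MulAdd3Way/ScalarMulAdd3Way loops handle the tail.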
template <typename TL, typename TR, int Cols>
inline void GEPP(
    const std::vector<SparseSlice<TL>*>& left_slices,
    const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
                           Eigen::Aligned>& right,
    const int num_cols, Matrix* output) {
  const int cols = (Cols == -1) ? num_cols : Cols;
  DCHECK_EQ(num_cols, cols);
  const int right_num_cols = right.dimension(1);
  const int output_num_cols = output->dimension(1);
  static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
  const int cols_mod = cols % kNumOperandsR;
  int k_offset = 0;
  // Pre-compute pointers for output matrix.
  float* out_ptrs[M];
  float* const out_start = &(*output)(0, 0);
  for (int j = 0; j < M; ++j) {
    out_ptrs[j] = out_start + output_num_cols * j;
  }
  for (const auto* left_slice : left_slices) {
    const auto& left = *left_slice;
    const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
    const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
    const int num_blocks = left.index3_offset.size();
    int begin3 = 0;
    int begin = 0;
    for (int i = 0; i < num_blocks; ++i) {
      // Pre-compute pointers for right matrix
      const TR* right_ptrs[K];
      const auto* const right_start = &right(k_offset, 0);
      DCHECK_LT(k_offset, right.dimension(0));
      for (int j = 0; j < K; ++j) {
        right_ptrs[j] = right_start + right_num_cols * j;
      }

      const int end3 = left.index3_offset[i];
      int j = begin3;
      // Loop unrolled for 2 iterations.
      for (; j + 1 < end3; j += 2) {
        Packet l1, l2, l3, nl1, nl2, nl3;
        LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
        const auto& index = left.index3[j];
        const auto& nindex = left.index3[j + 1];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];

        const auto* nr1 = right_ptrs[nindex.k1];
        const auto* nr2 = right_ptrs[nindex.k2];
        const auto* nr3 = right_ptrs[nindex.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
          MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
            MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
          }

          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
          const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
          const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
            ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
          }
        }
      }
      if (j < end3) {
        Packet l1, l2, l3;
        LoadThreeScalars(&data3, &l1, &l2, &l3);

        const auto& index = left.index3[j];
        float* out = out_ptrs[index.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
          }
          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
          }
        }
      }
      begin3 = end3;
      int end = left.index_offset[i];
      // Loop unrolled for 4 iterations.
      j = begin;
      for (; j + 3 < end; j += 4) {
        Packet l, nl, n2l, n3l;
        LoadFourScalars(&data, &l, &nl, &n2l, &n3l);

        const auto& index = left.index[j];
        const auto& nindex = left.index[j + 1];
        const auto& n2index = left.index[j + 2];
        const auto& n3index = left.index[j + 3];
        const auto* r = right_ptrs[index.k];
        const auto* nr = right_ptrs[nindex.k];
        const auto* n2r = right_ptrs[n2index.k];
        const auto* n3r = right_ptrs[n3index.k];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        float* n2out = out_ptrs[n2index.m];
        float* n3out = out_ptrs[n3index.m];

        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
          MulAdd(nl, &nr, &nout);
          MulAdd(n2l, &n2r, &n2out);
          MulAdd(n3l, &n3r, &n3out);
        }

        const float sl1 = Eigen::internal::pfirst<Packet>(l);
        const float sl2 = Eigen::internal::pfirst<Packet>(nl);
        const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
        const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl1, &r, &out);
          ScalarMulAdd(sl2, &nr, &nout);
          ScalarMulAdd(sl3, &n2r, &n2out);
          ScalarMulAdd(sl4, &n3r, &n3out);
        }
      }
      while (j < end) {
        Packet l;
        LoadSingleScalar(&data, &l);
        const auto& index = left.index[j];
        const auto* r = right_ptrs[index.k];
        float* out = out_ptrs[index.m];
        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
        }
        const float sl = Eigen::internal::pfirst<Packet>(l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl, &r, &out);
        }
        j++;
      }
      k_offset += left.block_size;
      begin = end;
    }
  }
}

#undef LOAD
#undef EXPAND_BFLOAT_L
#undef EXPAND_BFLOAT_U
#undef STORE
#undef FMA

}  // namespace

template <typename TL, typename TR>
class SparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // Not used; added to match interface of LibxsmmSparseMatMul
  struct TensorInfoCache {};

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  // Computes multiplication of left and num_cols columns of right, and stores
  // the output block in *"output" at offsets "output_row_offset" and
  // "output_col_offset". If assign is true, assigns the value to that block,
  // else adds the values to the existing values.
  static inline void ComputeOutputBlock(
      const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
      int num_cols, int output_row_offset, int output_col_offset, bool assign,
      bool transpose_output, MatrixMap* output);

  // Encodes "mat" using a sparse representation and stores that in
  // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
  // "slice_num_cols", each grid element is converted into a SparseSlice and
  // stored in mat_slices. "slice_block_size" is used to perform further column
  // blocking of each slice.
  static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
      const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
      int slice_block_size, int slice_num_cols,
      std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
      const DeviceBase::CpuWorkerThreads* thread_pool);

  // This function chops "mat" along the column dimension into pieces with at
  // most N columns, and concatenates the pieces one after the other in
  // "buffer". The list of pieces is returned in "slices". It returns a
  // BlockingCounter which should be used to wait for the shuffle operations
  // to complete.
  static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
      const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
      int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
      MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);

  // Helper function for CreateDenseSlices to move the data around. It returns
  // a BlockingCounter which should be used to wait for the shuffle operations
  // to complete.
  static inline BlockingCounter* ShuffleMatrix(
      const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
      int slice_col_start, int slice_num_cols, const int N,
      const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);

  // Helper function for CreateDenseSlices to create slices.
  static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
                                 const int num_slices,
                                 std::vector<ConstMatrixMapR*>* slices);

  // Heuristics to compute various block sizes.
  // KR, NR: block sizes for "right". We run blocking iterations that operate
  // on matrices with at most this size.
  // KL: grid size along the column dimension used while encoding left.
  // IB, JB: number of left and right slices to multiply together. This is used
  // for ordering different ComputeBlockOutput operations inside each blocking
  // iteration so as to potentially reduce the working set size.
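  // For example, with 16 threads the heuristic assumes 8 cores and budgets
  // 8 * 128 * 1024 (~1M) floats for one block of "right", so a 10000 x 3000
  // rhs gets KR = 4096 and NR = 256 (a 4096 x 256 block per iteration).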
  static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
                                       const ConstMatrixMapR& right,
                                       bool transpose_left, int num_threads,
                                       int* KR, int* NR, int* KL, int* JB,
                                       int* IB);

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename TL, typename TR>
class LibxsmmSparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // This structure contains a set of libxsmm kernels for sizes that have been
  // encountered previously by this operator so that libxsmm does not need to
  // reallocate its scratchpad memory each time (which hurts performance
  // substantially).
  struct TensorInfoCache {
    struct TensorInfoCacheEntry {
      // Parameters for kernel
      int M;
      int K;
      int N;
      int max_threads;
      // libxsmm handle and matrix data
      libxsmm_spmdm_handle handle;
      libxsmm_CSR_sparseslice* output_csr;
      // Chain to non-libxsmm implementation's cache in case that ever becomes
      // useful (it is an empty struct right now)
      typename SparseMatMul<TL, TR>::TensorInfoCache
          non_libxsmm_cache;  // Currently not used
    };
    // protects entries; invariant: entries is a valid std::multimap
    tensorflow::mutex lock;
    // Because there could be multiple matrix multiplies with the same sizes
    // going on at the same time, we need to allow multiple cache entries for a
    // given set of parameters. Taking and returning entries is used to make
    // sure the same cache entry is not used from two threads at a time.
    std::multimap<std::tuple<int, int, int, int>,
                  std::unique_ptr<TensorInfoCacheEntry>>
        entries GUARDED_BY(lock);

    TensorInfoCache() : lock(), entries() {}
    // Look up and remove first entry with these parameters, creating one if
    // there isn't one
    std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
                                                           int max_threads)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(M, K, N, max_threads);
      auto it = entries.find(key);
      if (it != entries.end()) {
        auto val = std::move(it->second);
        entries.erase(it);
        return val;
      } else {
        std::unique_ptr<TensorInfoCacheEntry> e{
            new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
        // setup scoped allocator, which uses cpu_allocator() for this scope
        const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
        libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
        return e;
      }
    }
    // Add a cache entry with certain parameters
    void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
      entries.insert(std::make_pair(key, std::move(e)));
    }
    ~TensorInfoCache() {
      tensorflow::mutex_lock ml(lock);
      for (auto& p : entries) {
        libxsmm_spmdm_destroy(&p.second->handle);
      }
      entries.clear();
    }

   private:
    TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
  };

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
};
#endif

template <typename TL, typename TR,
          template <typename TL2, typename TR2> class DoMatMul>
class SparseMatMulOp : public OpKernel {
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;

 public:
  explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("a is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("b is not a matrix"));

    const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
    const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
    const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
    const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);

    OP_REQUIRES(ctx, k == k2,
                errors::InvalidArgument(
                    "Matrix size incompatible: a: ", a.shape().DebugString(),
                    ", b: ", b.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));

    if (k == 0) {
      // If the inner dimension k in the matrix multiplication is zero, we fill
      // the output with zeros.
      functor::SetZeroFunctor<CPUDevice, float> f;
      f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
      return;
    }

    auto out = output->matrix<float>();

    std::unique_ptr<Tensor> a_float;
    std::unique_ptr<Tensor> b_float;
    if (!a_is_sparse_ && !b_is_sparse_) {
      auto left = &a;
      auto right = &b;
      // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
      if (std::is_same<TL, bfloat16>::value) {
        a_float.reset(new Tensor(DT_FLOAT, a.shape()));
        BFloat16ToFloat(a.flat<bfloat16>().data(),
                        a_float->flat<float>().data(), a.NumElements());
        left = a_float.get();
      }
      if (std::is_same<TR, bfloat16>::value) {
        b_float.reset(new Tensor(DT_FLOAT, b.shape()));
        BFloat16ToFloat(b.flat<bfloat16>().data(),
                        b_float->flat<float>().data(), b.NumElements());
        right = b_float.get();
      }
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0].first = transpose_a_ ? 0 : 1;
      dim_pair[0].second = transpose_b_ ? 1 : 0;

      out.device(ctx->template eigen_device<CPUDevice>()) =
          left->matrix<float>().contract(right->matrix<float>(), dim_pair);
      return;
    }

    auto left = &a;
    auto right = &b;
    bool transpose_output = false;
    bool transpose_a = transpose_a_;
    bool transpose_b = transpose_b_;
    if (!a_is_sparse_) {
      // Swap the order of multiplications using the identity:
      // A * B = (B' * A')'.
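      // Here "a" is not marked sparse (so "b" is), while the kernel expects
      // the sparse operand on the left; we therefore compute output' = b' * a'
      // and mark the output as transposed.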
      std::swap(left, right);
      std::swap(transpose_a, transpose_b);
      transpose_a = !transpose_a;
      transpose_b = !transpose_b;
      transpose_output = !transpose_output;
    }

    std::unique_ptr<Tensor> right_tr;
    if (transpose_b) {
      // TODO(agarwal): avoid transposing the matrix here and directly handle
      // transpose in CreateDenseSlices.
      right_tr.reset(
          new Tensor(right->dtype(),
                     TensorShape({right->dim_size(1), right->dim_size(0)})));

      const auto perm = dsizes_10();
      if (transpose_output) {
        right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TL>().shuffle(perm);
      } else {
        right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TR>().shuffle(perm);
      }
      right = right_tr.get();
    }

    if (transpose_output) {
      DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
                                right->matrix<TL>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    } else {
      DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
                                right->matrix<TR>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    }
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool a_is_sparse_;
  bool b_is_sparse_;

  // Cache for non-transposed-output multiply
  typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
  // Cache for transposed-output multiply
  typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
};

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
    const std::vector<SparseSlice<TL>*>& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
    int output_row_offset, int output_col_offset, bool assign,
    bool transpose_output, MatrixMap* output) {
  const auto perm = dsizes_10();
  int num_rows = left[0]->num_rows;
  const int rhs_num_cols = right.dimension(1);
  DCHECK_LE(num_cols, rhs_num_cols);
  Matrix out(num_rows, rhs_num_cols);
  out.setZero();
  if (num_cols == N) {
    GEPP<TL, TR, N>(left, right, num_cols, &out);
  } else {
    GEPP<TL, TR, -1>(left, right, num_cols, &out);
  }
  if (!assign) {
    const DSizes begin(output_row_offset, output_col_offset);
    const DSizes sizes(num_rows, num_cols);
    if (transpose_output) {
      if (num_cols == rhs_num_cols) {
        output->shuffle(perm).slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
      }
    } else {
      if (num_cols == rhs_num_cols) {
        output->slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->slice(begin, sizes) += out.slice(zero, sizes);
      }
    }
  } else {
    std::unique_ptr<Matrix> out_tr;
    if (transpose_output) {
      out_tr.reset(new Matrix(rhs_num_cols, num_rows));
      *out_tr = out.shuffle(perm);
      std::swap(output_row_offset, output_col_offset);
      std::swap(num_rows, num_cols);
    }
    const Matrix& final_out = transpose_output ? *out_tr : out;
    for (int i = 0; i < num_rows; ++i) {
      memcpy(&(*output)(output_row_offset + i, output_col_offset),
             &final_out(i, 0), num_cols * sizeof(float));
    }
  }
}

template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter>
SparseMatMul<TL, TR>::CreateSparseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
    int slice_num_rows, int slice_block_size, int slice_num_cols,
    std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
    const DeviceBase::CpuWorkerThreads* thread_pool) {
  const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
  const int num_slices_dim0 =
      std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
  const int num_slices_dim1 =
      std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
  mat_slices->resize(num_slices_dim0);
  BlockingCounter* counter =
      new BlockingCounter(num_slices_dim0 * num_slices_dim1);
  auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
                                   SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
                                   int col_offset) {
    if (transpose) {
      sparse_slice->template Initialize<true>(*slice, col_offset);
    } else {
      sparse_slice->template Initialize<false>(*slice, col_offset);
    }
    delete slice;
    counter->DecrementCount();
  };
  for (int i = 0; i < num_slices_dim0; ++i) {
    (*mat_slices)[i].resize(num_slices_dim1);
    int num_rows =
        std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
    for (int j = 0; j < num_slices_dim1; ++j) {
      int num_cols =
          std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
      SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
      if (transpose) {
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(0, i * slice_num_rows), mat.dimensions());
      } else {
        DSizes d(num_rows, mat_num_cols);
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(i * slice_num_rows, 0), d);
      }
      auto* sparse_slice =
          new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
      (*mat_slices)[i][j] = sparse_slice;
      thread_pool->workers->Schedule(
          [=]() { work(sparse_slice, slice, slice_num_cols * j); });
    }
  }
  return std::unique_ptr<BlockingCounter>(counter);
}
#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);

template <int NUM_ELEM = -1>
ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
                                                  int num_elements) {
  DCHECK_GE(kNumOperands, 8);
  static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
  const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
  DCHECK_EQ(num, num_elements);
  const float* src = reinterpret_cast<const float*>(bsrc);
  float* dst = reinterpret_cast<float*>(bdst);
  for (int index = 0; index + kStep <= num; index += kStep) {
    auto in = LOAD(src);
    auto tmp = INTERLEAVE(in);
    STORE(dst, tmp);
    src += kNumOperands;
    dst += kNumOperands;
  }
  if (num % kStep != 0) {
    memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
  }
}

template <typename T>
ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
                                          int num_elements) {
  if (std::is_same<T, float>::value || kNumOperands < 8) {
    memcpy(dst, src, num_elements * sizeof(T));
  } else if (std::is_same<T, bfloat16>::value) {
    if (num_elements == N) {
      CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
    } else {
      CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
    }
  } else {
    LOG(FATAL) << "Unsupported type";
  }
}

#undef LOAD
#undef INTERLEAVE
#undef STORE
1224
1225template <typename TL, typename TR>
1226inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
1227    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
1228    int slice_row_start, int slice_num_rows, int slice_col_start,
1229    int slice_num_cols, const int N,
1230    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
1231  DCHECK_EQ(N % 2, 0);
1232  DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
1233  int num_threads = std::min(thread_pool->num_threads, 16);
1234  BlockingCounter* counter = new BlockingCounter(num_threads);
1235  DCHECK_EQ(N, buffer->dimension(1));
1236  auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
1237                       slice_num_cols, N, buffer, counter](int s, int e) {
1238    const int row_start = s % slice_num_rows + slice_row_start;
1239    const int col_start = s / slice_num_rows * N + slice_col_start;
1240    auto* out_start = &(*buffer)(s, 0);
1241    const auto* input_start = &mat(row_start, col_start);
1242    const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
1243                                 slice_col_start + slice_num_cols - 1);
1244    const int mat_num_cols = mat.dimension(1);
1245    const int row_slice_size = slice_num_rows * mat_num_cols;
1246
1247    const int aligned_end = slice_num_cols / N * slice_num_rows;
1248    const int e1 = std::min(e, aligned_end);
1249    while (s < e1) {
1250      CopyAndMayBeInterleave<TR>(out_start, input_start, N);
1251      out_start += N;
1252      input_start += mat_num_cols;
1253      if (input_start > input_end) {
1254        input_start = input_start - row_slice_size + N;
1255      }
1256      ++s;
1257    }
1258    int s1 = std::max(s, aligned_end);
1259    const int copy_num_cols = slice_num_cols % N;
1260    while (s1 < e) {
1261      CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
1262      out_start += N;
1263      input_start += mat_num_cols;
1264      ++s1;
1265    }
1266    if (counter) counter->DecrementCount();
1267  };
1268
1269  int start = 0;
1270  int end = 0;
1271  int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
1272  DCHECK_LE(num_out_rows, buffer->dimension(0));
1273  for (int i = std::max(1, num_threads); i > 0; --i) {
1274    end = start + num_out_rows / i;
1275    thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
1276    num_out_rows -= (end - start);
1277    start = end;
1278  }
1279  return counter;
1280}
1281
1282template <typename TL, typename TR>
1283inline void SparseMatMul<TL, TR>::SliceMatrix(
1284    const MatrixR& mat, const int num_rows, const int num_slices,
1285    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1286  slices->resize(num_slices);
1287  DSizes d(num_rows, mat.dimension(1));
1288  DCHECK_LE(num_rows * num_slices, mat.dimension(0));
1289  for (int i = 0; i < num_slices; ++i) {
1290    (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
1291  }
1292}
1293
1294template <typename TL, typename TR>
1295inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
1296    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
1297    int num_rows, int col_start, int num_cols,
1298    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
1299    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1300  std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
1301      mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
1302  const int num_slices = (num_cols + N - 1) / N;
1303  SliceMatrix(*buffer, num_rows, num_slices, slices);
1304  return shuffle_counter;
1305}
1306
1307template <typename TL, typename TR>
1308inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
1309    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
1310    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
1311    bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
1312    int* IB) {
1313  // Heuristics for calculating block sizes
1314  // Assume two hyperthreads per core.
1315  const int est_num_cores = std::max(1, (num_threads + 1) / 2);
1316  // Use block of rhs with at most 128K floats per core.
1317  const int mem = est_num_cores * 128 * 1024;
1318  *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
1319  *NR = right.dimension(1);
1320  if (*KR * *NR > mem) {
1321    // 4096 may be enough to amortize the cost of writes.
1322    *KR = std::min<int>(*KR, 4096);
1323  }
1324  // Use sizes that are multiples of K and 256.
1325  *KR = std::max(1, *KR / K) * K;
1326  *NR = std::max(1, *NR / 256) * 256;
1327  if (*KR * *NR > mem) {
1328    *NR = mem / *KR;
1329  }
1330  *NR = std::max(1, *NR / 256) * 256;
1331
1332  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1333  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1334  for (*KL = 1024; *KL > K; *KL /= 2) {
1335    if (*KR % *KL == 0 &&
1336        std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
1337      break;
1338    }
1339  }
1340  DCHECK_EQ(*KL % K, 0);
1341  DCHECK_GE(*KR, *KL);
1342  if (*KR < right.dimension(0)) {
1343    CHECK_EQ(*KR % *KL, 0);
1344  }
1345
1346  *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
1347  *IB = 8 * *JB;
1348  DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
1349}
1350
1351#ifdef TENSORFLOW_USE_LIBXSMM
1352
1353template <typename F>
1354void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
1355                       const F& f) {
1356  int num_threads = thread_pool->num_threads;
1357  if (num_threads == 0) {
1358    LOG(FATAL) << "Have 0 threads in thread pool";
1359  } else if (num_threads == 1) {
1360    f(0);
1361  } else {
1362    BlockingCounter counter(num_threads - 1);
1363    for (int i = 1; i < num_threads; ++i) {
1364      thread_pool->workers->Schedule([&, i]() {
1365        f(i);
1366        counter.DecrementCount();
1367      });
1368    }
1369    f(0);
1370    counter.Wait();
1371  }
1372}
1373
1374template <typename T>
1375struct empty_type_wrapper {};
1376
1377// Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
1378// allow overloading
1379void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1380    empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1381    const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
1382    int tid, int nthreads) {
1383  return libxsmm_spmdm_createSparseSlice_fp32_thread(
1384      handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
1385}
1386void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1387    empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1388    char transA, const bfloat16* A,
1389    libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
1390    int nthreads) {
1391  return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
1392      handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a,
1393      block_id, tid, nthreads);
1394}
1395
void wrapper_libxsmm_spmdm_compute_generic_thread(
    empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
    char transA, char transB, const bfloat16* alpha,
    libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
    const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
  return libxsmm_spmdm_compute_bfloat16_thread(
      handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
      reinterpret_cast<const uint16*>(B), transC,
      reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
}
void wrapper_libxsmm_spmdm_compute_generic_thread(
    empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
    char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
    const float* B, char transC, const float* beta, float* C, int block_id,
    int tid, int nthreads) {
  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
                                           A_sparse, B, transC, beta, C,
                                           block_id, tid, nthreads);
}

template <typename TL, typename TR>
inline void LibxsmmSparseMatMul<TL, TR>::Compute(
    typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  if (false) {
    // Fallback path for inputs not handled by libxsmm; currently disabled.
    SparseMatMul<TL, TR>::Compute(
        nullptr /* Assumes no cached data for fallback */, left, right,
        transpose_left, thread_pool, transpose_output, output);
    return;
  }
  const int num_threads = thread_pool->num_threads;
  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  CHECK_EQ(left_dim1, right_dim0);
  CHECK_EQ(left_dim0,
           (transpose_output ? output->dimension(1) : output->dimension(0)));
  CHECK_EQ(right_dim1,
           (transpose_output ? output->dimension(0) : output->dimension(1)));
  if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
    // Causes problems in libxsmm
    SparseMatMul<TL, TR>::Compute(
        nullptr /* Assumes no cached data for fallback */, left, right,
        transpose_left, thread_pool, transpose_output, output);
    return;
  }
  auto left_data = left.data();
  auto right_data = right.data();
  auto output_data = output->data();
  // Initialize libxsmm for this matrix; make sure another thread doesn't use
  // this handle
  auto entry =
      cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
  // Convert the left matrix to compressed sparse row (CSR) format
  ptrdiff_t total_num_creation_blocks =
      libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
  std::atomic<int> cur_create_block_number;
  cur_create_block_number.store(0);
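  // Each thread repeatedly claims the next unprocessed block by atomically
  // incrementing the shared counter until all blocks are done; the same
  // work-distribution pattern is used for the compute blocks below.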
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_create_block_number.fetch_add(1);
      if (work_item >= total_num_creation_blocks) break;
      wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
          i, num_threads);
    }
  });
  // Do matrix-matrix multiplication
  ptrdiff_t total_num_mult_blocks =
      libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
  std::atomic<int> cur_mult_block_number;
  cur_mult_block_number.store(0);
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_mult_block_number.fetch_add(1);
      if (work_item >= total_num_mult_blocks) break;
      const TL alpha(1.0);  // Stored in a variable so we can get a pointer
      const TL beta(0.0);   // Stored in a variable so we can get a pointer
      wrapper_libxsmm_spmdm_compute_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
          right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
          work_item, i, num_threads);
    }
  });
  // Put handle + CSR storage back into cache
  cache->return_cache_entry(std::move(entry));
}

#endif  // TENSORFLOW_USE_LIBXSMM

// Here is an overview of the SparseMatMul code. Note that we assume that the
// left matrix is sparse.
//
// The matrix "left" is divided into a grid with block size (M, KL). Each
// block is encoded as a SparseSlice. These grid elements are stored as
// std::vector<std::vector<SparseSlice>>. Each element of the outer vector
// represents M rows of the left matrix. Let's call these elements l_i and
// let's call each element of the inner vector L_mk.
//
// The matrix "right" is divided into a grid with block size KR * NR.  Let's
// denote the blocks on the right as R_kn. Note that we ensure that KL divides
// KR so that for each element R_kn, we don't need to multiply it with any
// partial L_mk blocks.
//
// We then multiply each right-side block R_kn with the full "left" matrix and
// update the output. These iterations are run sequentially since R_kn are
// packed into the same underlying temporary buffer.
//
// In each iteration we do the following:
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
//    (=128) columns and then concatenate these slices into a buffer. This is
//    done so that each slice r_j of R_kn is stored contiguously in memory.
//    Note that if R_kn has dimensions (KR, NR), we create NR / N slices, and
//    the buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
// 2. For each (l_i, r_j), we compute the inner product using the GEPP function
//    and update the output block o_ij. These calls are further blocked to
//    reduce the working set size. In each iteration we take IB elements from
//    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
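//
// As a purely hypothetical illustration (the block sizes here are example
// values, not necessarily what ComputeBlockSizes picks): with M = 64, K = 64,
// N = 128 and block sizes KL = 256, KR = 512, NR = 1024, a 1024 x 2048 "left"
// matrix is sliced into 16 row groups l_i of 64 rows, each holding 8
// SparseSlices of width 256. A 2048 x 4096 "right" matrix yields 4 x 4 blocks
// R_kn of size 512 x 1024; each R_kn is repacked into 8 contiguous slices r_j
// of 128 columns, and each (l_i, r_j) pair then contributes one GEPP call,
// using the KR / KL = 2 overlapping left slices, that updates a 64 x 128
// block of the output.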
template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::Compute(
    typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  const int num_threads = thread_pool->num_threads;
  int KR, NR, KL, JB, IB;
  ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
                    &JB, &IB);
  // Slice the left matrix
  std::vector<std::vector<SparseSlice<TL>*>> left_slices;
  std::unique_ptr<BlockingCounter> sparse_slice_counter =
      CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
                         transpose_left, M, K, KL, &left_slices, thread_pool);
  const int num_left_slices = left_slices.size();

  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  // Allocate buffer for storing slices of right matrix.
  // Note buffer needs enough space to hold at most a KR * NR matrix since that
  // is the block size per iteration.
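  // The buffer stacks the (up to NR / N) column slices of the current right
  // block vertically, so that each N-column slice r_j is contiguous in memory.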
  const int buffer_num_rows =
      std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N;
  MatrixR buffer(buffer_num_rows, N);
  std::vector<ConstMatrixMapR*> right_slices;

  std::vector<SparseSlice<TL>*> block_left_slices;
  std::vector<std::function<void(void)>> tasks;
  // Number of blocks based on block sizes of KR * NR.
  const int num_k_blocks = (right_dim0 + KR - 1) / KR;
  const int num_n_blocks = (right_dim1 + NR - 1) / NR;
  std::unique_ptr<BlockingCounter> dense_slice_counter;

  for (int nb = 0; nb < num_n_blocks; ++nb) {
    const int right_num_cols =
        std::min(NR, static_cast<int>(right_dim1 - NR * nb));
    for (int kb = 0; kb < num_k_blocks; ++kb) {
      const int right_num_rows =
          std::min(KR, static_cast<int>(right_dim0 - KR * kb));
      dense_slice_counter = CreateDenseSlices(
          right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
          &buffer, &right_slices);
      const int num_right_slices = right_slices.size();
      tasks.reserve(num_left_slices * num_right_slices);
      for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
        for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
          for (int j_inner = j_outer;
               j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
            const int num_cols = std::min(N, right_num_cols - N * j_inner);
            for (int i_inner = i_outer;
                 i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
              block_left_slices.clear();
              int begin = kb * KR / KL;
              int end = std::min<int>((kb + 1) * KR / KL,
                                      (right.dimension(0) + KL - 1) / KL);
              DCHECK_LT(begin, end);
              block_left_slices.insert(block_left_slices.begin(),
                                       left_slices[i_inner].begin() + begin,
                                       left_slices[i_inner].begin() + end);
              tasks.push_back(std::bind(
                  &ComputeOutputBlock, block_left_slices,
                  std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
                  N * j_inner + nb * NR, kb == 0, transpose_output, output));
            }
          }
        }
      }
      if (sparse_slice_counter) {
        sparse_slice_counter->Wait();
        sparse_slice_counter.reset(nullptr);
      }
      if (dense_slice_counter) {
        dense_slice_counter->Wait();
        dense_slice_counter.reset(nullptr);
      }
      BlockingCounter bc(tasks.size());
      for (const auto& t : tasks) {
        thread_pool->workers->Schedule([&bc, &t]() {
          t();
          bc.DecrementCount();
        });
      }
      bc.Wait();
      tasks.clear();
      gtl::STLDeleteElements(&right_slices);
      right_slices.clear();
    }
  }
  for (auto& left_slice : left_slices) {
    gtl::STLDeleteElements(&left_slice);
  }
}

#define REGISTER_SPARSE_MATMUL(TA, TB)                   \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, SparseMatMul>);
#ifdef TENSORFLOW_USE_LIBXSMM
#define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
#endif

REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);

REGISTER_SPARSE_MATMUL(float, bfloat16);

REGISTER_SPARSE_MATMUL(bfloat16, float);

#ifdef TENSORFLOW_USE_LIBXSMM
REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
#else
REGISTER_SPARSE_MATMUL(float, float);
#endif

#undef REGISTER_SPARSE_MATMUL
#ifdef TENSORFLOW_USE_LIBXSMM
#undef REGISTER_SPARSE_MATMUL_LIBXSMM
#endif

}  // end namespace tensorflow