/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"

#include <memory>
#include <vector>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::SetToFirstInsertPoint;

namespace cpu {

namespace {
// Loads a tile of values from a 2D tensor.
class TileLoader {
 public:
  // Constructs a TileLoader that will load a tile consisting of
  // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
  // `major_dim_offset` in the major dimension.  The tile size along the minor
  // dimension is the vector size, which is implicitly determined by `vsl`.
  TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
             llvm::Value* matrix, int64 matrix_size_along_minor_dim,
             llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
      : vsl_(vsl) {
    pointers_.reserve(tile_size_along_major_dim);
    for (int64 i = 0; i < tile_size_along_major_dim; i++) {
      llvm::Value* total_offset = ir_builder->CreateMul(
          ir_builder->getInt64(matrix_size_along_minor_dim),
          ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset));
      pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
    }
  }

  // Loads a tile consisting of `tile_size_along_major_dim` vectors starting at
  // `major_dim_offset` in the major dimension and `minor_dim_offset` in the
  // minor dimension.
  std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
    std::vector<llvm::Value*> result;
    result.reserve(pointers_.size());
    for (const auto& pointer : pointers_) {
      result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
    }
    return result;
  }

 private:
  VectorSupportLibrary* vsl_;
  std::vector<llvm::Value*> pointers_;
};

// Computes a dot product of an "[M,K]{0,1}" lhs with a [K,1] vector (the
// layout of the vector does not matter).  This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
//   +----------------------+---+
//   |                      |   |
//   |                      |   |
//   |         A            | B |
//   |                      |   |
//   |                      |   |
//   |                      |   |
//   +----------------------+---+
//   |         C            | D |
//   +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
//   +---+---+---+---+       +--+--+--+--+
//   |M00|M10|M20|M30|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M02|M12|M22|M32|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M03|M13|M23|M33|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//
// (Legend: rows are horizontal and columns are vertical; each column is one
// llvm::Value of a vector type)
//
// where:
//
//   a. The left tile is from the column major left matrix.
//   b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
//      vector loaded from the RHS vector.
//
// As we iterate through the column dimension, we compute the change to the
// result vector by an elementwise multiplication between the two tiles above
// followed by a reduction along the major dimension:
//
//                     +-----------------------------------+
//                     | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
//                     +-----------------------------------+
//                     | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
// Result[R:R+4] +=    +-----------------------------------+
//                     | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
//                     +-----------------------------------+
//                     | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
//                     +-----------------------------------+
//
// where R is the starting row for the tile.
//
// We have an inner epilogue loop to deal with the "C" submatrix and an outer
// epilogue loop to deal with the B and D submatrices.
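//
// A rough sketch of the emitted loop structure (illustrative pseudocode
// only; the actual implementation emits LLVM IR through
// KernelSupportLibrary and VectorSupportLibrary):
//
//   for (col = 0; col < k - k % tile_cols; col += tile_cols)    // outer.tiled
//     for (row = 0; row < m - m % tile_rows; row += tile_rows)  // inner.tiled
//       result[row:row+tile_rows] +=
//           lhs[row:row+tile_rows, col:col+tile_cols] *
//           broadcast_each(rhs[col:col+tile_cols])
//     for (row = m - m % tile_rows; row < m; row++)          // inner epilogue
//       result[row] +=
//           dot(lhs[row, col:col+tile_cols], rhs[col:col+tile_cols])
//   if (k % tile_cols != 0)               // outer epilogue, covers B and D
//     run the body above once more with the remaining columns as the tile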

// TODO(sanjoy): We should investigate whether gather loads and scatter stores
// can be used here to share the same inner loop between the column-major and
// row-major matrix-vector products.
class ColumnMajorMatrixVectorProductEmitter {
 public:
  ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
                                        int64 tile_rows, int64 tile_cols,
                                        int64 m, int64 k, llvm::Value* lhs,
                                        llvm::Value* rhs, llvm::Value* addend,
                                        llvm::Value* result,
                                        llvm::IRBuilder<>* ir_builder)
      : scalar_type_(scalar_type),
        tile_rows_(tile_rows),
        tile_cols_(tile_cols),
        m_(m),
        k_(k),
        lhs_(lhs),
        rhs_(rhs),
        addend_(addend),
        result_(result),
        ir_builder_(ir_builder),
        ksl_(ir_builder_),
        vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
    CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_)));
  }

  void Emit();

 private:
  void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
                         bool is_first_column);

  TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) {
    return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
                      /*matrix_size_along_minor_dim=*/m_,
                      /*major_dim_offset=*/column_start,
                      /*tile_size_along_major_dim=*/column_count);
  }

  // Loads a tile of values from the RHS.  For the RHS a "tile" is a contiguous
  // sequence of `count` values, each one broadcasted to the vector width.
  std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
    llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
    std::vector<llvm::Value*> result;
    result.reserve(count);
    for (int64 i = 0; i < count; i++) {
      result.push_back(vsl_.LoadBroadcast(base_pointer, i));
    }
    return result;
  }

  void EmitInnerLoopTiled(TileLoader* lhs_tile_loader,
                          const std::vector<llvm::Value*>& rhs_tile,
                          int64 columns, bool is_first_column);

  void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
                             bool is_first_tiled_column);

  PrimitiveType scalar_type_;
  int64 tile_rows_;
  int64 tile_cols_;
  int64 m_;
  int64 k_;
  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* addend_;
  llvm::Value* result_;
  llvm::IRBuilder<>* ir_builder_;
  KernelSupportLibrary ksl_;
  VectorSupportLibrary vsl_;
};

void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
    llvm::Value* column, int64 column_count, bool is_first_column) {
  TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column,
                                                /*column_count=*/column_count);

  std::vector<llvm::Value*> rhs_tile =
      LoadRhsTile(column, /*count=*/column_count);
  EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile,
                     /*columns=*/column_count, is_first_column);
  EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
}

void ColumnMajorMatrixVectorProductEmitter::Emit() {
  // See the comment on the class declaration for the algorithm used here.
  int64 column_remainder = k_ % tile_cols_;
  int64 column_limit = k_ - column_remainder;

  ksl_.For("dot.outer.tiled",
           /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_,
           [&](llvm::Value* column, bool is_first_column) {
             EmitOuterLoopBody(column, tile_cols_, is_first_column);
           });

  if (column_remainder != 0) {
    EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder,
                      column_limit == 0);
  }
}

void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile,
    int64 columns, bool is_first_column) {
  int64 row_limit = m_ - (m_ % tile_rows_);

  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
           /*step=*/tile_rows_, [&](llvm::Value* row) {
             std::vector<llvm::Value*> lhs_tile =
                 lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row);
             llvm::Value* accumulator =
                 is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
                                            : vsl_.GetZeroVector())
                                 : vsl_.LoadVector(result_, row);
             for (int i = 0; i < columns; i++) {
               accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
             }
             vsl_.StoreVector(accumulator, result_, row);
           });
}

void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
  int64 row_start = m_ - (m_ % tile_rows_);
  if (row_start == m_) {
    return;
  }

  llvm::Value* columns_llvm = ir_builder_->getInt64(columns);

  // for (col = current_tile_col; col < (columns + current_tile_col); col++)
  //   for (row = row_start; row < m_; row++) {
  //     result[row] += lhs[row, col] * rhs[col]
  //     // Also take into account that if col is 0 then result[row] is not
  //     // initialized.
  //   }

  ksl_.For(
      "dot.inner.epilg.outer", /*start=*/current_tile_col,
      /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col),
      /*step=*/1, /*peel_first_iteration=*/false,
      [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
        llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
        llvm::Value* total_offset =
            ir_builder_->CreateMul(col, ir_builder_->getInt64(m_));
        llvm::Value* lhs_base_pointer =
            vsl_.ComputeOffsetPointer(lhs_, total_offset);
        ksl_.For(
            "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_,
            /*step=*/1, [&](llvm::Value* scalar_row) {
              llvm::Value* product = vsl_.Mul(
                  vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
              llvm::Value* setting_result_first_time = ir_builder_->CreateAnd(
                  is_first_scalar_col,
                  ir_builder_->getInt1(is_first_tiled_column));
              ksl_.If(
                  setting_result_first_time,
                  /*true_block_generator=*/
                  [&]() {
                    if (addend_) {
                      vsl_.StoreScalar(
                          vsl_.Add(vsl_.LoadScalar(addend_, scalar_row),
                                   product),
                          result_, scalar_row);
                    } else {
                      vsl_.StoreScalar(product, result_, scalar_row);
                    }
                  },
                  /*false_block_generator=*/
                  [&]() {
                    vsl_.StoreScalar(
                        vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
                        result_, scalar_row);
                  });
            });
      });
}

// Computes a dot product of an "[M,K]{1,0}" lhs with a [K,1] vector (the
// layout of the vector does not matter).  This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
//   +----------------------+---+
//   |                      |   |
//   |                      |   |
//   |         A            | B |
//   |                      |   |
//   |                      |   |
//   |                      |   |
//   +----------------------+---+
//   |         C            | D |
//   +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
//   +---+---+---+---+
//   |M00|M10|M20|M30|
//   +---+---+---+---+       +--+--+--+--+
//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M02|M12|M22|M32|
//   +---+---+---+---+
//   |M03|M13|M23|M33|
//   +---+---+---+---+
//
// (Legend: rows are horizontal and columns are vertical; each row is one
// llvm::Value of a vector type)
//
// where:
//
//   a. The left tile is loaded from the row major left matrix.
//   b. The right vector is loaded from the RHS vector.
//
// We keep 4 vector accumulators accumulating the following four vector
// expressions as we iterate over the row dimension:
//
//   +------+------+------+------+
//   |M0I*V0|M1I*V1|M2I*V2|M3I*V3|  for I in [0,4)
//   +------+------+------+------+
//
// In the end we do a horizontal reduction over these 4 vector accumulators to
// get 4 values in the result vector.
//
// We have an inner epilogue loop to deal with the "B" submatrix and an outer
// epilogue loop to deal with the C and D submatrices.
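//
// A rough sketch of the emitted loop structure (illustrative pseudocode
// only; the actual implementation emits LLVM IR through
// KernelSupportLibrary and VectorSupportLibrary):
//
//   for (row = 0; row < m - m % tile_rows; row += tile_rows)    // outer.tiled
//     vector_acc[0..tile_rows) = {0, ..., 0}
//     for (col = 0; col < k - k % tile_cols; col += tile_cols)  // inner.tiled
//       for (i = 0; i < tile_rows; i++)
//         vector_acc[i] +=
//             lhs[row+i, col:col+tile_cols] * rhs[col:col+tile_cols]
//     for (col = k - k % tile_cols; col < k; col++)          // inner epilogue
//       for (i = 0; i < tile_rows; i++)
//         scalar_acc[i] += lhs[row+i, col] * rhs[col]
//     result[row:row+tile_rows] =
//         horizontal_sum(vector_acc) + scalar_acc (+ addend, if present)
//   if (m % tile_rows != 0)               // outer epilogue, covers C and D
//     run the body above once more with the remaining rows as the tile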
class RowMajorMatrixVectorProductEmitter {
 public:
  RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows,
                                     int64 tile_cols, int64 m, int64 k,
                                     llvm::Value* lhs, llvm::Value* rhs,
                                     llvm::Value* addend, llvm::Value* result,
                                     llvm::IRBuilder<>* ir_builder)
      : scalar_type_(scalar_type),
        tile_rows_(tile_rows),
        tile_cols_(tile_cols),
        m_(m),
        k_(k),
        lhs_(lhs),
        rhs_(rhs),
        addend_(addend),
        result_(result),
        ir_builder_(ir_builder),
        ksl_(ir_builder_),
        vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
    CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
  }

  void Emit();

 private:
  TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) {
    return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
                      /*matrix_size_along_minor_dim=*/k_,
                      /*major_dim_offset=*/row_start,
                      /*tile_size_along_major_dim=*/row_count);
  }

  void EmitOuterLoopBody(llvm::Value* row, int64 row_count);

  void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows,
                          std::vector<VectorVariable>* vector_accumulators);

  void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
                             std::vector<ScalarVariable>* scalar_accumulators);

  PrimitiveType scalar_type_;
  int64 tile_rows_;
  int64 tile_cols_;
  int64 m_;
  int64 k_;
  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* addend_;
  llvm::Value* result_;
  llvm::IRBuilder<>* ir_builder_;
  KernelSupportLibrary ksl_;
  VectorSupportLibrary vsl_;
};

void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
                                                           int64 row_count) {
  TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row,
                                                /*row_count=*/row_count);
  std::vector<VectorVariable> vector_accumulators;
  std::vector<ScalarVariable> scalar_accumulators;
  for (int i = 0; i < row_count; i++) {
    vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
    scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
  }
  EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count,
                     &vector_accumulators);
  EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
                        &scalar_accumulators);

  std::vector<llvm::Value*> accumulator_values;
  std::transform(
      vector_accumulators.begin(), vector_accumulators.end(),
      std::back_inserter(accumulator_values),
      [](const VectorVariable& vector_var) { return vector_var.Get(); });

  std::vector<llvm::Value*> horizontal_sums;
  if (row_count == vsl_.vector_size()) {
    if (addend_) {
      horizontal_sums = vsl_.ComputeHorizontalSums(
          std::move(accumulator_values), vsl_.LoadVector(addend_, row));
    } else {
      horizontal_sums =
          vsl_.ComputeHorizontalSums(std::move(accumulator_values));
    }
  } else {
    horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values));
  }

  for (int i = 0; i < row_count; i++) {
    llvm::Value* result_value =
        vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
    llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row);
    if (addend_ && row_count != vsl_.vector_size()) {
      result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
    }
    vsl_.StoreScalar(result_value, result_, offset);
  }
}

void RowMajorMatrixVectorProductEmitter::Emit() {
  // See the comment on the class declaration for the algorithm used here.
  int64 row_remainder = m_ % tile_rows_;
  int64 row_limit = m_ - row_remainder;

  ksl_.For("dot.outer.tiled",
           /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_,
           [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); });

  if (row_remainder != 0) {
    EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
  }
}

void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    TileLoader* lhs_tile_loader, int64 rows,
    std::vector<VectorVariable>* vector_accumulators) {
  int64 column_limit = k_ - (k_ % tile_cols_);

  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
           /*step=*/tile_cols_, [&](llvm::Value* col) {
             std::vector<llvm::Value*> lhs_tile =
                 lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col);
             llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
             for (int i = 0; i < rows; i++) {
               llvm::Value* old_sum = (*vector_accumulators)[i].Get();
               (*vector_accumulators)[i].Set(
                   vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
             }
           });
}

void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    llvm::Value* current_tile_row, int64 rows,
    std::vector<ScalarVariable>* scalar_accumulators) {
  int64 column_start = k_ - (k_ % tile_cols_);
  if (column_start == k_) {
    return;
  }

  for (int r = 0; r < rows; r++) {
    llvm::Value* total_offset = ir_builder_->CreateMul(
        ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
        ir_builder_->getInt64(k_));
    llvm::Value* lhs_base_pointer =
        vsl_.ComputeOffsetPointer(lhs_, total_offset);
    ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_,
             /*step=*/1, [&](llvm::Value* scalar_col) {
               llvm::Value* product =
                   vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
                            vsl_.LoadScalar(rhs_, scalar_col));
               llvm::Value* old_value = (*scalar_accumulators)[r].Get();
               (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
             });
  }
}

}  // namespace

DotOpEmitter::DotOpEmitter(
    const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
    const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features)
    : dot_(dot),
      transpose_lhs_(transpose_lhs),
      transpose_rhs_(transpose_rhs),
      target_array_(target_array),
      lhs_array_(lhs_array),
      rhs_array_(rhs_array),
      addend_array_(addend_array),
      executable_run_options_value_(executable_run_options_value),
      ir_builder_(ir_builder),
      hlo_module_config_(hlo_module_config),
      target_machine_features_(target_machine_features) {}

/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation(
    const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
    const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType type = target_array.GetShape().element_type();
  TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
  DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
                           lhs_array, rhs_array, addend_array,
                           executable_run_options_value, ir_builder,
                           hlo_module_config, target_machine_features);
  return dot_emitter.Emit();
}

bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; }

bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
  if (dot_.shape().dimensions_size() != 2) {
    return false;
  }

  PrimitiveType primitive_type = dot_.shape().element_type();

  if (!primitive_util::IsFloatingPointType(primitive_type) &&
      !primitive_util::IsIntegralType(primitive_type)) {
    return false;
  }

  MatMultDims mat_mult_dims = GetMatMultDims();
  bool is_column_major_matrix_vector = false;
  bool is_row_major_matrix_vector = false;

  int64 m, k;
  bool swap_operands;

  if (mat_mult_dims.m == 1) {
    bool rhs_effectively_row_major =
        transpose_rhs_ ^ !mat_mult_dims.rhs_column_major;
    if (rhs_effectively_row_major) {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;
      is_column_major_matrix_vector = true;
      swap_operands = true;
    } else {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;
      is_row_major_matrix_vector = true;
      swap_operands = true;
    }
  }

  if (mat_mult_dims.n == 1) {
    bool lhs_effectively_column_major =
        transpose_lhs_ ^ mat_mult_dims.lhs_column_major;
    if (lhs_effectively_column_major) {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_column_major_matrix_vector = true;
      swap_operands = false;
    } else {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_row_major_matrix_vector = true;
      swap_operands = false;
    }
  }

  if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
    return false;
  }

  int64 tiling_factor = GetGemvTilingFactor();
  CHECK_GT(tiling_factor, 0);

  llvm::Value* result_op = target_array_.GetBasePointer();
  llvm::Value* lhs_op =
      swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
  llvm::Value* rhs_op =
      swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();

  const bool enable_fast_math =
      hlo_module_config_.debug_options().xla_enable_fast_math();
  const bool optimize_for_size =
      options::OptimizeForSizeRequested(hlo_module_config_);

  const int target_vector_register_element_size =
      target_machine_features_.vector_register_num_elements(
          *ir_builder_->GetInsertBlock()->getParent(), primitive_type);

  // We may not always know the vector register size for the target we're
  // compiling against, in which case target_vector_register_element_size is 0.
  // In these cases we choose a default LLVM IR register size.
  const int kUnknownTargetVectorRegisterSize = 4;
  const int vector_register_element_size =
      target_vector_register_element_size == 0
          ? kUnknownTargetVectorRegisterSize
          : target_vector_register_element_size;

  if (is_column_major_matrix_vector) {
    VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
            << " and k = " << k;
    int64 tile_rows = vector_register_element_size;
    int64 tile_cols = tiling_factor;

    string kernel_name = tensorflow::strings::StrCat(
        "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
        "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");

    KernelSupportLibrary::EmitAndCallOutlinedKernel(
        /*enable_fast_math=*/enable_fast_math,
        /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
        lhs_op, rhs_op,
        addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
        [this, tile_rows, tile_cols, m, k, primitive_type](
            llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
            llvm::Value* result_op) {
          ColumnMajorMatrixVectorProductEmitter emitter(
              primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
              addend_op, result_op, ir_builder_);
          emitter.Emit();
        });
  } else {
    VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
            << " and k = " << k;
    int64 tile_rows = tiling_factor;
    int64 tile_cols = vector_register_element_size;

    string kernel_name = tensorflow::strings::StrCat(
        "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
        "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");

    KernelSupportLibrary::EmitAndCallOutlinedKernel(
        /*enable_fast_math=*/enable_fast_math,
        /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
        lhs_op, rhs_op,
        addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
        [this, tile_rows, tile_cols, m, k, primitive_type](
            llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
            llvm::Value* result_op) {
          RowMajorMatrixVectorProductEmitter emitter(
              primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
              addend_op, result_op, ir_builder_);
          emitter.Emit();
        });
  }

  return true;
}

tensorflow::Status DotOpEmitter::Emit() {
  // The dot operation performs a sum of products over dimension 0 of the left
  // hand side operand and dimension 1 of the right hand side operand.
  //
  // Let the shapes of lhs and rhs be defined as below:
  //
  //   lhs = [L{n-1} x L{n-2} x ... L{0}]
  //   rhs = [R{m-1} x R{m-2} x ... R{0}]
  //
  // The sum-of-products dimension in the lhs has size L{0} and the dimension in
  // the rhs has size R{1}. Necessarily, then:
  //
  //   L{0} == R{1}
  //
  // The output of the operation has the following shape:
  //
  //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
  //
  // To perform the operation we construct a loop nest with one for-loop for
  // each dimension of the output. Inside this loop nest is another for-loop
  // which performs the sum-of-products (the reduction loop) before storing
  // the result in the output buffer.
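  //
  // As a sketch, for rank-2 operands (and no transposition) the loop nest
  // computes the following (illustrative pseudocode only):
  //
  //   for (i = 0; i < L{1}; i++) {
  //     for (j = 0; j < R{0}; j++) {
  //       accum = 0;
  //       for (r = 0; r < L{0}; r++) {  // reduction loop; L{0} == R{1}
  //         accum += lhs[i, r] * rhs[r, j];
  //       }
  //       output[i, j] = accum;
  //     }
  //   }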

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();

  if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
                 ShapeUtil::IsScalar(rhs_shape));
    return EmitScalarDot();
  }

  if (EmitLlvmIrDotIfProfitable()) {
    return Status::OK();
  }

  CHECK_EQ(addend_array_, nullptr);

  if (PotentiallyImplementedAsEigenDot(dot_)) {
    return EmitCallToRuntime();
  }

  // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
  // case where the reduction dimension is 0 for both LHS and RHS. This results
  // in a vector dot product producing a scalar.
  int64 lhs_reduction_dimension = 0;
  if (ShapeUtil::Rank(lhs_shape) >= 2) {
    lhs_reduction_dimension =
        ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1);
  }
  int64 rhs_reduction_dimension = 0;
  if (ShapeUtil::Rank(rhs_shape) >= 2) {
    rhs_reduction_dimension =
        ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2);
  }

  // Verify that the reduction dimensions of the two operands are the same
  // size.
  TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
               rhs_shape.dimensions(rhs_reduction_dimension));

  bool lhs_reduction_along_minor_dimension =
      lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
  bool rhs_reduction_along_minor_dimension =
      rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);

  // Create loop nests which loop through the LHS operand dimensions and the RHS
  // operand dimensions. The reduction dimensions of the LHS and RHS are handled
  // in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), ir_builder_);
  llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest(
      &loop_nest, lhs_array_, lhs_reduction_dimension, "lhs");
  llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest(
      &loop_nest, rhs_array_, rhs_reduction_dimension, "rhs");

  // Create the loop which does the sum of products reduction.
  //
  // The prevent_unrolling bit works around a deficiency in LLVM's loop
  // vectorization pipeline, wherein in some cases unrolling a loop can prevent
  // effective vectorization.  Since we know that the IR we generate when
  // reducing across the minor dimension in both LHS and RHS is vectorized well
  // by the loop vectorizer, we block unrolling in that case to stop loop unroll
  // from messing up the vectorization.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
      /*prevent_unrolling=*/lhs_reduction_along_minor_dimension &&
          rhs_reduction_along_minor_dimension);

  // The final entry in the rhs and lhs indexes is the indvar of the
  // reduction loop.
  lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();

  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After the
  // reduction loop we load the result and store into the output array.

  // Function entry basic block.
  // - Emit alloca for accumulator
  llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
  SetToFirstInsertPoint(&func->getEntryBlock(), ir_builder_);
  llvm::Type* accum_type = target_array_.GetElementLlvmType();
  llvm::Value* accum_address = ir_builder_->CreateAlloca(
      accum_type, /*ArraySize=*/nullptr, "accum_address");

  // Preheader basic block of reduction loop:
  // - Initialize accumulator to zero.
  llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
  ir_builder_->SetInsertPoint(preheader_bb->getTerminator());

  ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type),
                           accum_address);

  // Body basic block of reduction loop:
  // - Load elements from lhs and rhs array.
  // - Multiply lhs-element and rhs-element.
  // - Load accumulator and add to product.
  // - Store sum back into accumulator.
  SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), ir_builder_);

  llvm::Value* lhs_element =
      lhs_array_.EmitReadArrayElement(lhs_index, ir_builder_);
  llvm::Value* rhs_element =
      rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_);

  llvm::Value* accum = ir_builder_->CreateLoad(accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto real = [&](llvm::Value* x) {
      return ir_builder_->CreateExtractValue(x, {0});
    };
    auto imag = [&](llvm::Value* x) {
      return ir_builder_->CreateExtractValue(x, {1});
    };
    llvm::Value* product_real = ir_builder_->CreateFSub(
        ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)),
        ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    llvm::Value* product_imag = ir_builder_->CreateFAdd(
        ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)),
        ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element)));
    updated_accum = ir_builder_->CreateInsertValue(
        accum, ir_builder_->CreateFAdd(real(accum), product_real), {0});
    updated_accum = ir_builder_->CreateInsertValue(
        updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1});
  } else {
    llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
    updated_accum = ir_builder_->CreateFAdd(accum, product);
  }
  ir_builder_->CreateStore(updated_accum, accum_address);

  // Exit basic block of reduction loop.
  // - Load accumulator value (the result).
  // - Store into output array.
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), ir_builder_);

  llvm::Value* result = ir_builder_->CreateLoad(accum_address);

  // Create index into target address. The target index is the concatenation of
  // the rhs and lhs indexes with the reduction dimensions removed. The terms
  // from the rhs index are the lower dimensions in the index so we add them
  // first.
  llvm_ir::IrArray::Index target_index;
  for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_index.push_back(lhs_index[dimension]);
    }
  }
  for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_index.push_back(rhs_index[dimension]);
    }
  }

  target_array_.EmitWriteArrayElement(target_index, result, ir_builder_);

  // Set the IR builder insert point to the exit basic block of the outer most
  // loop.
  ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());

  return tensorflow::Status::OK();
}

tensorflow::Status DotOpEmitter::EmitScalarDot() {
  // A scalar dot is just a scalar multiply.
  llvm::Value* result;
  llvm::Value* lhs_value =
      lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
  llvm::Value* rhs_value =
      rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
  if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
#define REAL(x) ir_builder_->CreateExtractValue(x, {0})
#define IMAG(x) ir_builder_->CreateExtractValue(x, {1})
    llvm::Value* real = ir_builder_->CreateFSub(
        ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)),
        ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value)));
    llvm::Value* imag = ir_builder_->CreateFAdd(
        ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)),
        ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value)));
#undef IMAG
#undef REAL
    result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
    result = ir_builder_->CreateInsertValue(result, real, {0});
    result = ir_builder_->CreateInsertValue(result, imag, {1});
  } else {
    result = ir_builder_->CreateFMul(lhs_value, rhs_value);
  }
  target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
  return tensorflow::Status::OK();
}

tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
  DCHECK(ShapesAreLegalForRuntimeDot());

  // The signature of the Eigen runtime matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64 m, int64 n, int64 k, int32 transpose_lhs,
  //          int32 transpose_rhs);
  //
  // The two transpose_... parameters are actually booleans, but we use int32
  // to avoid target-dependent calling convention details.
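  //
  // For illustration, a matching C declaration for the F32 entry point would
  // look something like the following (a sketch only; the authoritative
  // declarations live in cpu_runtime.h, and the exact symbol name is taken
  // from runtime::kEigenMatMulF32SymbolName below rather than hard-coded):
  //
  //   extern "C" void __xla_cpu_runtime_EigenMatMulF32(
  //       const void* run_options, float* out, float* lhs, float* rhs,
  //       tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k,
  //       tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs);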

  bool multi_threaded_eigen =
      hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
  PrimitiveType type = target_array_.GetShape().element_type();
  llvm::Type* float_type;
  const char* fn_name;
  switch (type) {
    case F32:
      fn_name = multi_threaded_eigen
                    ? runtime::kEigenMatMulF32SymbolName
                    : runtime::kEigenSingleThreadedMatMulF32SymbolName;
      float_type = ir_builder_->getFloatTy();
      break;
    case F64:
      fn_name = multi_threaded_eigen
                    ? runtime::kEigenMatMulF64SymbolName
                    : runtime::kEigenSingleThreadedMatMulF64SymbolName;
      float_type = ir_builder_->getDoubleTy();
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type).c_str());
  }

  llvm::Type* float_ptr_type = float_type->getPointerTo();
  llvm::Type* int64_type = ir_builder_->getInt64Ty();
  llvm::Type* int32_type = ir_builder_->getInt32Ty();
  llvm::Type* int8_ptr_type = ir_builder_->getInt8Ty()->getPointerTo();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      ir_builder_->getVoidTy(),
      {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
       int64_type, int64_type, int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();

  llvm::Function* matmul_func = llvm::cast<llvm::Function>(
      module->getOrInsertFunction(fn_name, matmul_type));
  matmul_func->setCallingConv(llvm::CallingConv::C);
  matmul_func->setDoesNotThrow();
  matmul_func->setOnlyAccessesArgMemory();

  // The Eigen runtime function expects column-major layout. If the matrices are
  // row major, then use the following identity to compute the product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes the
  // memory layout of a matrix from row-major to column-major or vice versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
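  //
  // As a concrete example (for illustration): a row-major [2 x 3] lhs has the
  // same memory layout as a column-major [3 x 2] matrix holding lhs^T.  So
  // for row-major inputs we ask Eigen for rhs^T x lhs^T, a column-major
  // [n x m] result, which occupies memory exactly like the row-major [m x n]
  // product we actually want.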

  MatMultDims mat_mult_dims = GetMatMultDims();

  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = transpose_lhs_;
  bool transpose_rhs = transpose_rhs_;

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  ir_builder_->CreateCall(
      matmul_func,
      {ir_builder_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
       ir_builder_->CreateBitCast(target_array_.GetBasePointer(),
                                  float_ptr_type),
       ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
       ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
       ir_builder_->getInt64(mat_mult_dims.m),
       ir_builder_->getInt64(mat_mult_dims.n),
       ir_builder_->getInt64(mat_mult_dims.k),
       ir_builder_->getInt32(transpose_lhs),
       ir_builder_->getInt32(transpose_rhs)});
  return tensorflow::Status::OK();
}

DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
  CHECK_EQ(dot_.shape().dimensions_size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();

  return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0),
          lhs_shape.dimensions(transpose_lhs_ ? 0 : 1),
          rhs_shape.dimensions(transpose_rhs_ ? 0 : 1),
          LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
          LayoutUtil::Minor(rhs_shape.layout(), 0) == 0};
}

llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
    llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
    int64 reduction_dimension, tensorflow::StringPiece name_suffix) {
  // Prepares the dimension list we will use to emit the loop nest. Outermost
  // loops are added first. Add loops in major-to-minor order, and skip the
  // reduction dimension.
  std::vector<int64> dimensions;
  const Shape& shape = operand_array.GetShape();
  for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) {
    int64 dimension = LayoutUtil::Minor(shape.layout(), i);
    if (dimension != reduction_dimension) {
      dimensions.push_back(dimension);
    }
  }

  // Create loop nest with one for-loop for each dimension of the
  // output.
  llvm_ir::IrArray::Index index =
      loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
  // Verify that every dimension except the reduction dimension was set in the
  // index.
  for (int dimension = 0; dimension < index.size(); ++dimension) {
    if (dimension == reduction_dimension) {
      DCHECK_EQ(nullptr, index[dimension]);
    } else {
      DCHECK_NE(nullptr, index[dimension]);
    }
  }
  return index;
}

// Returns whether the given shape is a matrix with no padding.
static bool IsRank2WithNoPadding(const Shape& shape) {
  return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                               const Shape& output_shape) {
  // The inputs and the output must
  // 1) be matrices with no padding, and
  // 2) have an allowed element type.
  return output_shape.element_type() == F32 &&
         IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
         IsRank2WithNoPadding(output_shape);
}

bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
  // For certain types of Dot, we can call Eigen.
  if (hlo.opcode() == HloOpcode::kDot) {
    const Shape& lhs_shape = hlo.operand(0)->shape();
    const Shape& rhs_shape = hlo.operand(1)->shape();

    if (ShapeUtil::HasZeroElements(lhs_shape) ||
        ShapeUtil::HasZeroElements(rhs_shape)) {
      return false;
    }

    if (ProfitableToImplementDotInTiledLlvmIr(hlo)) {
      return false;
    }

    // If gemm can accept the operand shapes, use it rather than a custom
    // kernel.
    if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
      // The size of the reduction dimension should match. The shape inference
      // guarantees this invariant, so the check here is for programming
      // errors.
      CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
      return true;
    }
  }

  if (hlo.opcode() == HloOpcode::kFusion &&
      hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
      hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
    auto* dot = hlo.fused_expression_root();
    const Shape& lhs_shape = dot->operand(0)->shape();
    const Shape& rhs_shape = dot->operand(1)->shape();
    if (ShapeUtil::HasZeroElements(lhs_shape) ||
        ShapeUtil::HasZeroElements(rhs_shape)) {
      return false;
    }
    return true;
  }

  return false;
}

// For vector-matrix dot products, it is always profitable to make the Rhs
// column major.
tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
    const HloInstruction& hlo) {
  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
      hlo.shape().dimensions(0) == 1) {
    if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
      return 1;
    }
    return {};
  }

  if (hlo.opcode() == HloOpcode::kFusion &&
      hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) {
    auto* fusion_root =
        hlo.fused_instructions_computation()->root_instruction();
    if (fusion_root->opcode() != HloOpcode::kAdd) {
      return {};
    }

    for (auto* fusion_root_op : fusion_root->operands()) {
      if (fusion_root_op->opcode() != HloOpcode::kDot) {
        continue;
      }
      if (auto operand_num =
              ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
        auto* operand = fusion_root_op->operand(*operand_num);
        if (operand->opcode() == HloOpcode::kParameter &&
            operand->user_count() == 1) {
          return operand->parameter_number();
        }
      }
    }
  }

  return {};
}

bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
  // Any matrix-vector product of floating point or integral type, or a
  // transpose-dot fusion of the same, can be lowered to a tiled LLVM IR
  // implementation.
  const Shape& shape = dot.shape();
  return shape.dimensions_size() == 2 &&
         (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
         (primitive_util::IsFloatingPointType(shape.element_type()) ||
          primitive_util::IsIntegralType(shape.element_type()));
}

}  // namespace cpu
}  // namespace xla