/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"

#include <memory>
#include <vector>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::SetToFirstInsertPoint;

namespace cpu {

namespace {
// Loads a tile of values from a 2D tensor.
class TileLoader {
 public:
  // Constructs a TileLoader that will load a tile consisting of
  // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
  // `major_dim_offset` in the major dimension.  The tile size along the minor
  // dimension is the vector size, and that is implicitly determined by `vsl`.
  TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
             llvm::Value* matrix, int64 matrix_size_along_minor_dim,
             llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
      : vsl_(vsl) {
    pointers_.reserve(tile_size_along_major_dim);
    for (int64 i = 0; i < tile_size_along_major_dim; i++) {
      llvm::Value* total_offset = ir_builder->CreateMul(
          ir_builder->getInt64(matrix_size_along_minor_dim),
          ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset));
      pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
    }
  }

  // Loads a tile consisting of `tile_size_along_major_dim` vectors, starting
  // at the `major_dim_offset` fixed at construction time in the major
  // dimension and at `minor_dim_offset` in the minor dimension.
  std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
    std::vector<llvm::Value*> result;
    result.reserve(pointers_.size());
    for (const auto& pointer : pointers_) {
      result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
    }
    return result;
  }

 private:
  VectorSupportLibrary* vsl_;
  std::vector<llvm::Value*> pointers_;
};
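
// For orientation, here is a hypothetical use of TileLoader (the variable
// names are made up for illustration): loading a 4-vector tile from `lhs`,
// whose minor dimension has size `m`, starting at major-dimension offset
// `col`:
//
//   TileLoader loader(&vsl, ir_builder, lhs,
//                     /*matrix_size_along_minor_dim=*/m,
//                     /*major_dim_offset=*/col,
//                     /*tile_size_along_major_dim=*/4);
//   std::vector<llvm::Value*> tile = loader.LoadTile(/*minor_dim_offset=*/row);
//
// `tile[i]` is then the vector at (row, col + i) in (minor, major)
// coordinates.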

// Computes a dot product between an "[M,K]{0,1} lhs" and a [K,1] vector (the
// layout of the vector does not matter).  This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
//   +----------------------+---+
//   |                      |   |
//   |                      |   |
//   |          A           | B |
//   |                      |   |
//   |                      |   |
//   |                      |   |
//   +----------------------+---+
//   |          C           | D |
//   +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
//   +---+---+---+---+       +--+--+--+--+
//   |M00|M10|M20|M30|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M01|M11|M21|M31|  and  |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M02|M12|M22|M32|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M03|M13|M23|M33|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//
// (Legend: rows are horizontal and columns are vertical; each column is one
// llvm::Value of a vector type.)
//
// where:
//
//  a. The left tile is from the column-major left matrix.
//  b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
//     vector loaded from the RHS vector.
//
// As we iterate through the column dimension, we compute the change to the
// result vector by an elementwise multiplication between the two tiles above
// followed by a reduction along the major dimension:
//
//                     +-----------------------------------+
//                     | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
//                     +-----------------------------------+
//                     | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
//   Result[R:R+4] +=  +-----------------------------------+
//                     | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
//                     +-----------------------------------+
//                     | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
//                     +-----------------------------------+
//
// where R is the starting row for the tile.
//
// We have an inner epilogue loop to deal with the "C" submatrix and an outer
// epilogue loop to deal with the "B" and "D" submatrices.
//
// TODO(sanjoy): We should investigate whether gather loads and scatter stores
// can be used here to share the same inner loop between column-major and
// row-major matrix-vector products.
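//
// As a reading aid, here is a rough scalar-C++ sketch of the computation the
// emitted IR performs (illustrative only; the class below emits vectorized
// LLVM IR, not this code):
//
//   // lhs is [m, k] column major, rhs is a [k] vector, result is an [m]
//   // vector.
//   for (int64 col = 0; col < k; col++)
//     for (int64 row = 0; row < m; row++) {
//       float product = lhs[col * m + row] * rhs[col];
//       result[row] =
//           (col == 0 ? (addend ? addend[row] : 0) : result[row]) + product;
//     }
//
// The tiled implementation covers tile_rows_ rows and tile_cols_ columns of
// this loop nest per iteration, with epilogue loops for the leftovers.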
class ColumnMajorMatrixVectorProductEmitter {
 public:
  ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
                                        int64 tile_rows, int64 tile_cols,
                                        int64 m, int64 k, llvm::Value* lhs,
                                        llvm::Value* rhs, llvm::Value* addend,
                                        llvm::Value* result,
                                        llvm::IRBuilder<>* ir_builder)
      : scalar_type_(scalar_type),
        tile_rows_(tile_rows),
        tile_cols_(tile_cols),
        m_(m),
        k_(k),
        lhs_(lhs),
        rhs_(rhs),
        addend_(addend),
        result_(result),
        ir_builder_(ir_builder),
        ksl_(ir_builder_),
        vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
    CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_)));
  }

  void Emit();

 private:
  void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
                         bool is_first_column);

  TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) {
    return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
                      /*matrix_size_along_minor_dim=*/m_,
                      /*major_dim_offset=*/column_start,
                      /*tile_size_along_major_dim=*/column_count);
  }

  // Loads a tile of values from the RHS.  For the RHS a "tile" is a contiguous
  // sequence of `count` values, each one broadcast to the vector width.
  std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
    llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
    std::vector<llvm::Value*> result;
    result.reserve(count);
    for (int64 i = 0; i < count; i++) {
      result.push_back(vsl_.LoadBroadcast(base_pointer, i));
    }
    return result;
  }

  void EmitInnerLoopTiled(TileLoader* lhs_tile_loader,
                          const std::vector<llvm::Value*>& rhs_tile,
                          int64 columns, bool is_first_column);

  void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
                             bool is_first_tiled_column);

  PrimitiveType scalar_type_;
  int64 tile_rows_;
  int64 tile_cols_;
  int64 m_;
  int64 k_;
  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* addend_;
  llvm::Value* result_;
  llvm::IRBuilder<>* ir_builder_;
  KernelSupportLibrary ksl_;
  VectorSupportLibrary vsl_;
};

void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
    llvm::Value* column, int64 column_count, bool is_first_column) {
  TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column,
                                                /*column_count=*/column_count);

  std::vector<llvm::Value*> rhs_tile =
      LoadRhsTile(column, /*count=*/column_count);
  EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile,
                     /*columns=*/column_count, is_first_column);
  EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
}

void ColumnMajorMatrixVectorProductEmitter::Emit() {
  // See the comment on the class declaration for the algorithm used here.
  int64 column_remainder = k_ % tile_cols_;
  int64 column_limit = k_ - column_remainder;

  ksl_.For("dot.outer.tiled",
           /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_,
           [&](llvm::Value* column, bool is_first_column) {
             EmitOuterLoopBody(column, tile_cols_, is_first_column);
           });

  if (column_remainder != 0) {
    EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder,
                      column_limit == 0);
  }
}

void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile,
    int64 columns, bool is_first_column) {
  int64 row_limit = m_ - (m_ % tile_rows_);

  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
           /*step=*/tile_rows_, [&](llvm::Value* row) {
             std::vector<llvm::Value*> lhs_tile =
                 lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row);
             llvm::Value* accumulator =
                 is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
                                            : vsl_.GetZeroVector())
                                 : vsl_.LoadVector(result_, row);
             for (int i = 0; i < columns; i++) {
               accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
             }
             vsl_.StoreVector(accumulator, result_, row);
           });
}
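
// In scalar terms, one "dot.inner.tiled" iteration above computes, for each
// row r in [row, row + tile_rows_):
//
//   result[r] = init(r) + lhs[r, column + 0] * rhs[column + 0] + ...
//                       + lhs[r, column + columns - 1] * rhs[column + columns - 1]
//
// where init(r) is addend[r] (or 0) on the first column tile and result[r]
// otherwise, and each "+" step is a fused multiply-add over a whole vector of
// rows at once.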

void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
  int64 row_start = m_ - (m_ % tile_rows_);
  if (row_start == m_) {
    return;
  }

  llvm::Value* columns_llvm = ir_builder_->getInt64(columns);

  // for (col = current_tile_col; col < (columns + current_tile_col); col++)
  //   for (row = row_start; row < m_; row++) {
  //     result[row] += lhs[row, col] * rhs[col]
  //     // Also take into account that if col is 0 then result[row] is not
  //     // initialized.
  //   }

  ksl_.For(
      "dot.inner.epilg.outer", /*start=*/current_tile_col,
      /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col),
      /*step=*/1, /*peel_first_iteration=*/false,
      [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
        llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
        llvm::Value* total_offset =
            ir_builder_->CreateMul(col, ir_builder_->getInt64(m_));
        llvm::Value* lhs_base_pointer =
            vsl_.ComputeOffsetPointer(lhs_, total_offset);
        ksl_.For(
            "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_,
            /*step=*/1, [&](llvm::Value* scalar_row) {
              llvm::Value* product = vsl_.Mul(
                  vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
              llvm::Value* setting_result_first_time = ir_builder_->CreateAnd(
                  is_first_scalar_col,
                  ir_builder_->getInt1(is_first_tiled_column));
              ksl_.If(
                  setting_result_first_time,
                  /*true_block_generator=*/
                  [&]() {
                    if (addend_) {
                      vsl_.StoreScalar(
                          vsl_.Add(vsl_.LoadScalar(addend_, scalar_row),
                                   product),
                          result_, scalar_row);
                    } else {
                      vsl_.StoreScalar(product, result_, scalar_row);
                    }
                  },
                  /*false_block_generator=*/
                  [&]() {
                    vsl_.StoreScalar(
                        vsl_.Add(vsl_.LoadScalar(result_, scalar_row),
                                 product),
                        result_, scalar_row);
                  });
            });
      });
}

// Computes a dot product between an "[M,K]{1,0} lhs" and a [K,1] vector (the
// layout of the vector does not matter).  This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
//   +----------------------+---+
//   |                      |   |
//   |                      |   |
//   |          A           | B |
//   |                      |   |
//   |                      |   |
//   |                      |   |
//   +----------------------+---+
//   |          C           | D |
//   +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
//   +---+---+---+---+
//   |M00|M10|M20|M30|
//   +---+---+---+---+       +--+--+--+--+
//   |M01|M11|M21|M31|  and  |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M02|M12|M22|M32|
//   +---+---+---+---+
//   |M03|M13|M23|M33|
//   +---+---+---+---+
//
// (Legend: rows are horizontal and columns are vertical; each row is one
// llvm::Value of a vector type.)
//
// where:
//
//  a. The left tile is loaded from the row-major left matrix.
//  b. The right vector is loaded from the RHS vector.
//
// We keep 4 vector accumulators, accumulating the following four vector
// expressions as we iterate over the row dimension:
//
//   +------+------+------+------+
//   |M0I*V0|M1I*V1|M2I*V2|M3I*V3|  for I in [0,4)
//   +------+------+------+------+
//
// In the end we do a horizontal reduction over these 4 vector accumulators to
// get 4 values in the result vector.
//
// We have an inner epilogue loop to deal with the "B" submatrix and an outer
// epilogue loop to deal with the "C" and "D" submatrices.
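//
// A rough scalar-C++ sketch of the computation (illustrative only; the class
// below emits vectorized LLVM IR, not this code):
//
//   // lhs is [m, k] row major, rhs is a [k] vector, result is an [m] vector.
//   for (int64 row = 0; row < m; row++) {
//     float accum = 0;
//     for (int64 col = 0; col < k; col++)
//       accum += lhs[row * k + col] * rhs[col];
//     result[row] = accum + (addend ? addend[row] : 0);
//   }
//
// The tiled implementation keeps one vector accumulator live per row in the
// tile across the column loop and reduces each one horizontally at the end.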
class RowMajorMatrixVectorProductEmitter {
 public:
  RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
                                     int64 tile_rows, int64 tile_cols,
                                     int64 m, int64 k, llvm::Value* lhs,
                                     llvm::Value* rhs, llvm::Value* addend,
                                     llvm::Value* result,
                                     llvm::IRBuilder<>* ir_builder)
      : scalar_type_(scalar_type),
        tile_rows_(tile_rows),
        tile_cols_(tile_cols),
        m_(m),
        k_(k),
        lhs_(lhs),
        rhs_(rhs),
        addend_(addend),
        result_(result),
        ir_builder_(ir_builder),
        ksl_(ir_builder_),
        vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
    CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
  }

  void Emit();

 private:
  TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) {
    return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
                      /*matrix_size_along_minor_dim=*/k_,
                      /*major_dim_offset=*/row_start,
                      /*tile_size_along_major_dim=*/row_count);
  }

  void EmitOuterLoopBody(llvm::Value* row, int64 row_count);

  void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows,
                          std::vector<VectorVariable>* vector_accumulators);

  void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
                             std::vector<ScalarVariable>* scalar_accumulators);

  PrimitiveType scalar_type_;
  int64 tile_rows_;
  int64 tile_cols_;
  int64 m_;
  int64 k_;
  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* addend_;
  llvm::Value* result_;
  llvm::IRBuilder<>* ir_builder_;
  KernelSupportLibrary ksl_;
  VectorSupportLibrary vsl_;
};

void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
                                                           int64 row_count) {
  TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row,
                                                /*row_count=*/row_count);
  std::vector<VectorVariable> vector_accumulators;
  std::vector<ScalarVariable> scalar_accumulators;
  for (int i = 0; i < row_count; i++) {
    vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
    scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
  }
  EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count,
                     &vector_accumulators);
  EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
                        &scalar_accumulators);

  std::vector<llvm::Value*> accumulator_values;
  std::transform(
      vector_accumulators.begin(), vector_accumulators.end(),
      std::back_inserter(accumulator_values),
      [](const VectorVariable& vector_var) { return vector_var.Get(); });

  std::vector<llvm::Value*> horizontal_sums;
  if (row_count == vsl_.vector_size()) {
    if (addend_) {
      horizontal_sums = vsl_.ComputeHorizontalSums(
          std::move(accumulator_values), vsl_.LoadVector(addend_, row));
    } else {
      horizontal_sums =
          vsl_.ComputeHorizontalSums(std::move(accumulator_values));
    }
  } else {
    horizontal_sums =
        vsl_.ComputeHorizontalSums(std::move(accumulator_values));
  }

  for (int i = 0; i < row_count; i++) {
    llvm::Value* result_value =
        vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
    llvm::Value* offset =
        ir_builder_->CreateAdd(ir_builder_->getInt64(i), row);
    if (addend_ && row_count != vsl_.vector_size()) {
      result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
    }
    vsl_.StoreScalar(result_value, result_, offset);
  }
}
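
// For example, with row_count == 4 and a vector width of 4, the four
// accumulators
//
//   acc0 = [a0, a1, a2, a3], ..., acc3 = [d0, d1, d2, d3]
//
// reduce horizontally to the four scalars (a0+a1+a2+a3), ..., (d0+d1+d2+d3),
// which become result[row + 0] .. result[row + 3] (plus the scalar-epilogue
// contributions and, if present, the addend).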

void RowMajorMatrixVectorProductEmitter::Emit() {
  // See the comment on the class declaration for the algorithm used here.
  int64 row_remainder = m_ % tile_rows_;
  int64 row_limit = m_ - row_remainder;

  ksl_.For("dot.outer.tiled",
           /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_,
           [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); });

  if (row_remainder != 0) {
    EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
  }
}

void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    TileLoader* lhs_tile_loader, int64 rows,
    std::vector<VectorVariable>* vector_accumulators) {
  int64 column_limit = k_ - (k_ % tile_cols_);

  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
           /*step=*/tile_cols_, [&](llvm::Value* col) {
             std::vector<llvm::Value*> lhs_tile =
                 lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col);
             llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
             for (int i = 0; i < rows; i++) {
               llvm::Value* old_sum = (*vector_accumulators)[i].Get();
               (*vector_accumulators)[i].Set(
                   vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
             }
           });
}

void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    llvm::Value* current_tile_row, int64 rows,
    std::vector<ScalarVariable>* scalar_accumulators) {
  int64 column_start = k_ - (k_ % tile_cols_);
  if (column_start == k_) {
    return;
  }

  for (int r = 0; r < rows; r++) {
    llvm::Value* total_offset = ir_builder_->CreateMul(
        ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
        ir_builder_->getInt64(k_));
    llvm::Value* lhs_base_pointer =
        vsl_.ComputeOffsetPointer(lhs_, total_offset);
    ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_,
             /*step=*/1, [&](llvm::Value* scalar_col) {
               llvm::Value* product =
                   vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
                            vsl_.LoadScalar(rhs_, scalar_col));
               llvm::Value* old_value = (*scalar_accumulators)[r].Get();
               (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
             });
  }
}

}  // namespace

DotOpEmitter::DotOpEmitter(
    const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
    const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features)
    : dot_(dot),
      transpose_lhs_(transpose_lhs),
      transpose_rhs_(transpose_rhs),
      target_array_(target_array),
      lhs_array_(lhs_array),
      rhs_array_(rhs_array),
      addend_array_(addend_array),
      executable_run_options_value_(executable_run_options_value),
      ir_builder_(ir_builder),
      hlo_module_config_(hlo_module_config),
      target_machine_features_(target_machine_features) {}
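
// A minimal sketch of how a caller might drive the static entry point below
// (the surrounding IR-emission state here is hypothetical, not part of this
// file):
//
//   TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
//       *dot_instruction, /*transpose_lhs=*/false, /*transpose_rhs=*/false,
//       target_array, lhs_array, rhs_array, /*addend_array=*/nullptr,
//       executable_run_options_value, &ir_builder, hlo_module_config,
//       target_machine_features));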

/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation(
    const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
    const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType type = target_array.GetShape().element_type();
  TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
  DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
                           lhs_array, rhs_array, addend_array,
                           executable_run_options_value, ir_builder,
                           hlo_module_config, target_machine_features);
  return dot_emitter.Emit();
}

bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; }

bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
  if (dot_.shape().dimensions_size() != 2) {
    return false;
  }

  PrimitiveType primitive_type = dot_.shape().element_type();

  if (!primitive_util::IsFloatingPointType(primitive_type) &&
      !primitive_util::IsIntegralType(primitive_type)) {
    return false;
  }

  MatMultDims mat_mult_dims = GetMatMultDims();
  bool is_column_major_matrix_vector = false;
  bool is_row_major_matrix_vector = false;

  int64 m, k;
  bool swap_operands;

  if (mat_mult_dims.m == 1) {
    bool rhs_effectively_row_major =
        transpose_rhs_ ^ !mat_mult_dims.rhs_column_major;
    if (rhs_effectively_row_major) {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;
      is_column_major_matrix_vector = true;
      swap_operands = true;
    } else {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;
      is_row_major_matrix_vector = true;
      swap_operands = true;
    }
  }

  if (mat_mult_dims.n == 1) {
    bool lhs_effectively_column_major =
        transpose_lhs_ ^ mat_mult_dims.lhs_column_major;
    if (lhs_effectively_column_major) {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_column_major_matrix_vector = true;
      swap_operands = false;
    } else {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_row_major_matrix_vector = true;
      swap_operands = false;
    }
  }

  if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
    return false;
  }

  int64 tiling_factor = GetGemvTilingFactor();
  CHECK_GT(tiling_factor, 0);

  llvm::Value* result_op = target_array_.GetBasePointer();
  llvm::Value* lhs_op = swap_operands ? rhs_array_.GetBasePointer()
                                      : lhs_array_.GetBasePointer();
  llvm::Value* rhs_op = swap_operands ? lhs_array_.GetBasePointer()
                                      : rhs_array_.GetBasePointer();

  const bool enable_fast_math =
      hlo_module_config_.debug_options().xla_enable_fast_math();
  const bool optimize_for_size =
      options::OptimizeForSizeRequested(hlo_module_config_);

  const int target_vector_register_element_size =
      target_machine_features_.vector_register_num_elements(
          *ir_builder_->GetInsertBlock()->getParent(), primitive_type);

  // We may not always know the vector register size for the target we're
  // compiling against, in which case target_vector_register_element_size is
  // 0.  In these cases we choose a default LLVM IR register size.
  const int kUnknownTargetVectorRegisterSize = 4;
  const int vector_register_element_size =
      target_vector_register_element_size == 0
          ? kUnknownTargetVectorRegisterSize
          : target_vector_register_element_size;
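
  // Worked example of the dispatch above: for a [1, K] x [K, N] dot with a
  // row-major, non-transposed RHS, mat_mult_dims.m == 1 and the RHS is
  // effectively row major, so we compute result^T = rhs^T * lhs^T instead.
  // rhs^T, viewed as an [N, K] matrix, is column major, so this becomes a
  // column-major matrix-vector product with m = N, k = K, and the operands
  // swapped.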
"_with_addend" : ""); 653 654 KernelSupportLibrary::EmitAndCallOutlinedKernel( 655 /*enable_fast_math=*/enable_fast_math, 656 /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, 657 lhs_op, rhs_op, 658 addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, 659 [this, tile_rows, tile_cols, m, k, primitive_type]( 660 llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, 661 llvm::Value* result_op) { 662 ColumnMajorMatrixVectorProductEmitter emitter( 663 primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, 664 addend_op, result_op, ir_builder_); 665 emitter.Emit(); 666 }); 667 } else { 668 VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m 669 << " and k = " << k; 670 int64 tile_rows = tiling_factor; 671 int64 tile_cols = vector_register_element_size; 672 673 string kernel_name = tensorflow::strings::StrCat( 674 "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, 675 "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); 676 677 KernelSupportLibrary::EmitAndCallOutlinedKernel( 678 /*enable_fast_math=*/enable_fast_math, 679 /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, 680 lhs_op, rhs_op, 681 addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, 682 [this, tile_rows, tile_cols, m, k, primitive_type]( 683 llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, 684 llvm::Value* result_op) { 685 RowMajorMatrixVectorProductEmitter emitter( 686 primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, 687 addend_op, result_op, ir_builder_); 688 emitter.Emit(); 689 }); 690 } 691 692 return true; 693} 694 695tensorflow::Status DotOpEmitter::Emit() { 696 // The dot operation performs a sum of products over dimension 0 of the left 697 // hand side operand and dimension 1 of the right hand side operand. 698 // 699 // Let the shapes of lhs and rhs be defined as below: 700 // 701 // lhs = [L{n-1} x L{n-2} x ... L{0}] 702 // rhs = [R{m-1} x R{m-2} x ... R{0}] 703 // 704 // The sum-of-products dimension in the lhs has size L{0} and the dimension in 705 // the rhs has size R{1}. Necessarily, then: 706 // 707 // L{0} == R{1} 708 // 709 // The output of the operation has the following shape: 710 // 711 // output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}] 712 // 713 // To perform the operation we construct a loop nest with one for-loop for 714 // each dimension of the output. Inside this loop nest is another for-loop 715 // which performs the sum-of-products (the reduction loop) before storing 716 // the result in the output buffer. 717 718 const Shape& lhs_shape = lhs_array_.GetShape(); 719 const Shape& rhs_shape = rhs_array_.GetShape(); 720 721 if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) { 722 // If the operands are scalar, don't emit any loops. 723 TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) && 724 ShapeUtil::IsScalar(rhs_shape)); 725 return EmitScalarDot(); 726 } 727 728 if (EmitLlvmIrDotIfProfitable()) { 729 return Status::OK(); 730 } 731 732 CHECK_EQ(addend_array_, nullptr); 733 734 if (PotentiallyImplementedAsEigenDot(dot_)) { 735 return EmitCallToRuntime(); 736 } 737 738 // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special 739 // case where the reduction dimension is 0 for both LHS and RHS. This results 740 // in a vector dot product producing a scalar. 
  int64 lhs_reduction_dimension = 0;
  if (ShapeUtil::Rank(lhs_shape) >= 2) {
    lhs_reduction_dimension =
        ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1);
  }
  int64 rhs_reduction_dimension = 0;
  if (ShapeUtil::Rank(rhs_shape) >= 2) {
    rhs_reduction_dimension =
        ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2);
  }

  // Verify that the reduction dimensions in the two operands are the same
  // size.
  TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
               rhs_shape.dimensions(rhs_reduction_dimension));

  bool lhs_reduction_along_minor_dimension =
      lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
  bool rhs_reduction_along_minor_dimension =
      rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);

  // Create loop nests which loop through the LHS operand dimensions and the
  // RHS operand dimensions.  The reduction dimensions of the LHS and RHS are
  // handled in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), ir_builder_);
  llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest(
      &loop_nest, lhs_array_, lhs_reduction_dimension, "lhs");
  llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest(
      &loop_nest, rhs_array_, rhs_reduction_dimension, "rhs");

  // Create the loop which does the sum of products reduction.
  //
  // The prevent_unrolling bit works around a deficiency in LLVM's loop
  // vectorization pipeline, wherein unrolling a loop can, in some cases,
  // prevent effective vectorization.  Since we know that the IR we generate
  // when reducing across the minor dimension in both LHS and RHS is
  // vectorized well by the loop vectorizer, we block unrolling in that case
  // to stop loop unrolling from interfering with the vectorization.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
      /*prevent_unrolling=*/lhs_reduction_along_minor_dimension &&
          rhs_reduction_along_minor_dimension);

  // The final entry in the rhs and lhs indexes is the indvar of the
  // reduction loop.
  lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();

  // For computing the sum of products we alloca a single location to store
  // the dot product result as we accumulate it within the reduction loop.
  // After the reduction loop we load the result and store it into the output
  // array.

  // Function entry basic block.
  // - Emit alloca for accumulator.
  llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
  SetToFirstInsertPoint(&func->getEntryBlock(), ir_builder_);
  llvm::Type* accum_type = target_array_.GetElementLlvmType();
  llvm::Value* accum_address = ir_builder_->CreateAlloca(
      accum_type, /*ArraySize=*/nullptr, "accum_address");

  // Preheader basic block of reduction loop:
  // - Initialize accumulator to zero.
  llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
  ir_builder_->SetInsertPoint(preheader_bb->getTerminator());

  ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type),
                           accum_address);

  // Body basic block of reduction loop:
  // - Load elements from lhs and rhs array.
  // - Multiply lhs-element and rhs-element.
  // - Load accumulator and add to product.
  // - Store sum back into accumulator.
  SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), ir_builder_);

  llvm::Value* lhs_element =
      lhs_array_.EmitReadArrayElement(lhs_index, ir_builder_);
  llvm::Value* rhs_element =
      rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_);

  llvm::Value* accum = ir_builder_->CreateLoad(accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    // Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
    auto real = [&](llvm::Value* x) {
      return ir_builder_->CreateExtractValue(x, {0});
    };
    auto imag = [&](llvm::Value* x) {
      return ir_builder_->CreateExtractValue(x, {1});
    };
    llvm::Value* product_real = ir_builder_->CreateFSub(
        ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)),
        ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    llvm::Value* product_imag = ir_builder_->CreateFAdd(
        ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)),
        ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element)));
    updated_accum = ir_builder_->CreateInsertValue(
        accum, ir_builder_->CreateFAdd(real(accum), product_real), {0});
    updated_accum = ir_builder_->CreateInsertValue(
        updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag),
        {1});
  } else {
    llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
    updated_accum = ir_builder_->CreateFAdd(accum, product);
  }
  ir_builder_->CreateStore(updated_accum, accum_address);

  // Exit basic block of reduction loop.
  // - Load accumulator value (the result).
  // - Store into output array.
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), ir_builder_);

  llvm::Value* result = ir_builder_->CreateLoad(accum_address);

  // Create the index into the target address.  The target index is the
  // concatenation of the lhs and rhs indexes with the reduction dimensions
  // removed.  The terms from the lhs index form the leading dimensions of the
  // output index, so we add them first.
  llvm_ir::IrArray::Index target_index;
  for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_index.push_back(lhs_index[dimension]);
    }
  }
  for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_index.push_back(rhs_index[dimension]);
    }
  }

  target_array_.EmitWriteArrayElement(target_index, result, ir_builder_);

  // Set the IR builder insert point to the exit basic block of the outermost
  // loop.
  ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());

  return tensorflow::Status::OK();
}
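
// In scalar terms, the loop nest emitted by Emit() above computes roughly
// (eliding layout details and the complex-number case):
//
//   for (each combination of non-reduction indices of lhs and rhs) {
//     float accum = 0;
//     for (int64 i = 0; i < reduction_size; i++)
//       accum += lhs[..., i] * rhs[i, ...];  // i at each reduction dimension
//     output[lhs indices ++ rhs indices] = accum;
//   }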

tensorflow::Status DotOpEmitter::EmitScalarDot() {
  // A scalar dot is just a scalar multiply.
  llvm::Value* result;
  llvm::Value* lhs_value =
      lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
  llvm::Value* rhs_value =
      rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
  if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
#define REAL(x) ir_builder_->CreateExtractValue(x, {0})
#define IMAG(x) ir_builder_->CreateExtractValue(x, {1})
    llvm::Value* real = ir_builder_->CreateFSub(
        ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)),
        ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value)));
    llvm::Value* imag = ir_builder_->CreateFAdd(
        ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)),
        ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value)));
#undef IMAG
#undef REAL
    result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
    result = ir_builder_->CreateInsertValue(result, real, {0});
    result = ir_builder_->CreateInsertValue(result, imag, {1});
  } else {
    result = ir_builder_->CreateFMul(lhs_value, rhs_value);
  }
  target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
  return tensorflow::Status::OK();
}

tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
  DCHECK(ShapesAreLegalForRuntimeDot());

  // The signature of the Eigen runtime matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64 m, int64 n, int64 k, int32 transpose_lhs,
  //          int32 transpose_rhs);
  //
  // The two transpose_... parameters are actually booleans, but we use int32
  // to avoid target-dependent calling convention details.

  bool multi_threaded_eigen =
      hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
  PrimitiveType type = target_array_.GetShape().element_type();
  llvm::Type* float_type;
  const char* fn_name;
  switch (type) {
    case F32:
      fn_name = multi_threaded_eigen
                    ? runtime::kEigenMatMulF32SymbolName
                    : runtime::kEigenSingleThreadedMatMulF32SymbolName;
      float_type = ir_builder_->getFloatTy();
      break;
    case F64:
      fn_name = multi_threaded_eigen
                    ? runtime::kEigenMatMulF64SymbolName
                    : runtime::kEigenSingleThreadedMatMulF64SymbolName;
      float_type = ir_builder_->getDoubleTy();
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type).c_str());
  }

  llvm::Type* float_ptr_type = float_type->getPointerTo();
  llvm::Type* int64_type = ir_builder_->getInt64Ty();
  llvm::Type* int32_type = ir_builder_->getInt32Ty();
  llvm::Type* int8_ptr_type = ir_builder_->getInt8Ty()->getPointerTo();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      ir_builder_->getVoidTy(),
      {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
       int64_type, int64_type, int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();

  llvm::Function* matmul_func = llvm::cast<llvm::Function>(
      module->getOrInsertFunction(fn_name, matmul_type));
  matmul_func->setCallingConv(llvm::CallingConv::C);
  matmul_func->setDoesNotThrow();
  matmul_func->setOnlyAccessesArgMemory();

  // The Eigen runtime function expects column-major layout.  If the matrices
  // are row major, then use the following identity to compute the product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes
  // the memory layout of a matrix from row-major to column-major or vice
  // versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.

  MatMultDims mat_mult_dims = GetMatMultDims();

  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = transpose_lhs_;
  bool transpose_rhs = transpose_rhs_;

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  ir_builder_->CreateCall(
      matmul_func,
      {ir_builder_->CreateBitCast(executable_run_options_value_,
                                  int8_ptr_type),
       ir_builder_->CreateBitCast(target_array_.GetBasePointer(),
                                  float_ptr_type),
       ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
       ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
       ir_builder_->getInt64(mat_mult_dims.m),
       ir_builder_->getInt64(mat_mult_dims.n),
       ir_builder_->getInt64(mat_mult_dims.k),
       ir_builder_->getInt32(transpose_lhs),
       ir_builder_->getInt32(transpose_rhs)});
  return tensorflow::Status::OK();
}

DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
  CHECK_EQ(dot_.shape().dimensions_size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();

  return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0),
          lhs_shape.dimensions(transpose_lhs_ ? 0 : 1),
          rhs_shape.dimensions(transpose_rhs_ ? 0 : 1),
          LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
          LayoutUtil::Minor(rhs_shape.layout(), 0) == 0};
}

llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
    llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
    int64 reduction_dimension, tensorflow::StringPiece name_suffix) {
  // Prepares the dimension list we will use to emit the loop nest.  Outermost
  // loops are added first.  Add loops in major-to-minor order, and skip the
  // reduction dimension.
  std::vector<int64> dimensions;
  const Shape& shape = operand_array.GetShape();
  for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) {
    int64 dimension = LayoutUtil::Minor(shape.layout(), i);
    if (dimension != reduction_dimension) {
      dimensions.push_back(dimension);
    }
  }

  // Create the loop nest with one for-loop for each dimension of the
  // output.
  llvm_ir::IrArray::Index index =
      loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
  // Verify that every dimension except the reduction dimension was set in the
  // index.
  for (int dimension = 0; dimension < index.size(); ++dimension) {
    if (dimension == reduction_dimension) {
      DCHECK_EQ(nullptr, index[dimension]);
    } else {
      DCHECK_NE(nullptr, index[dimension]);
    }
  }
  return index;
}

// Returns whether the given shape is a matrix with no padding.
static bool IsRank2WithNoPadding(const Shape& shape) {
  return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
}

// In a gemm operation where output = lhs * rhs, check whether the given
// shapes are valid for the operation.
static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                               const Shape& output_shape) {
  // The inputs and the output must
  // 1) be matrices with no padding, and
  // 2) have an allowed element type.
  return output_shape.element_type() == F32 &&
         IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
         IsRank2WithNoPadding(output_shape);
}

bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
  // For certain types of Dot, we can call Eigen.
  if (hlo.opcode() == HloOpcode::kDot) {
    const Shape& lhs_shape = hlo.operand(0)->shape();
    const Shape& rhs_shape = hlo.operand(1)->shape();

    if (ShapeUtil::HasZeroElements(lhs_shape) ||
        ShapeUtil::HasZeroElements(rhs_shape)) {
      return false;
    }

    if (ProfitableToImplementDotInTiledLlvmIr(hlo)) {
      return false;
    }

    // If gemm can accept the operand shapes, use it rather than a custom
    // kernel.
    if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
      // The size of the reduction dimension should match.  The shape
      // inference guarantees this invariant, so the check here is for
      // programming errors.
      CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
      return true;
    }
  }

  if (hlo.opcode() == HloOpcode::kFusion &&
      hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
      hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
    auto* dot = hlo.fused_expression_root();
    const Shape& lhs_shape = dot->operand(0)->shape();
    const Shape& rhs_shape = dot->operand(1)->shape();
    if (ShapeUtil::HasZeroElements(lhs_shape) ||
        ShapeUtil::HasZeroElements(rhs_shape)) {
      return false;
    }
    return true;
  }

  return false;
}

// For vector-matrix dot products, it is always profitable to make the RHS
// column major.
tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
    const HloInstruction& hlo) {
  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
      hlo.shape().dimensions(0) == 1) {
    if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
      return 1;
    }
    return {};
  }

  if (hlo.opcode() == HloOpcode::kFusion &&
      hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) {
    auto* fusion_root =
        hlo.fused_instructions_computation()->root_instruction();
    if (fusion_root->opcode() != HloOpcode::kAdd) {
      return {};
    }

    for (auto* fusion_root_op : fusion_root->operands()) {
      if (fusion_root_op->opcode() != HloOpcode::kDot) {
        continue;
      }
      if (auto operand_num =
              ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
        auto* operand = fusion_root_op->operand(*operand_num);
        if (operand->opcode() == HloOpcode::kParameter &&
            operand->user_count() == 1) {
          return operand->parameter_number();
        }
      }
    }
  }

  return {};
}

bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
  // Any matrix-vector product of floating point or integral type, or a
  // transpose-dot fusion of the same, can be lowered to a tiled LLVM IR
  // implementation.
  const Shape& shape = dot.shape();
  return shape.dimensions_size() == 2 &&
         (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
         (primitive_util::IsFloatingPointType(shape.element_type()) ||
          primitive_util::IsIntegralType(shape.element_type()));
}

}  // namespace cpu
}  // namespace xla