1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/transpose_folding.h" 17 18#include <memory> 19#include <unordered_set> 20#include <vector> 21 22#include "tensorflow/compiler/xla/client/computation_builder.h" 23#include "tensorflow/compiler/xla/literal_util.h" 24#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" 25#include "tensorflow/compiler/xla/service/hlo_computation.h" 26#include "tensorflow/compiler/xla/service/hlo_instruction.h" 27#include "tensorflow/compiler/xla/service/hlo_module.h" 28#include "tensorflow/compiler/xla/service/hlo_opcode.h" 29#include "tensorflow/compiler/xla/service/shape_inference.h" 30#include "tensorflow/compiler/xla/shape_util.h" 31#include "tensorflow/compiler/xla/test.h" 32#include "tensorflow/compiler/xla/test_helpers.h" 33#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 34#include "tensorflow/compiler/xla/xla_data.pb.h" 35#include "tensorflow/core/platform/logging.h" 36 37namespace xla { 38namespace { 39 40class TransposeFoldingTest : public HloTestBase { 41 protected: 42 void FoldTranspose(HloModule* module) { 43 TransposeFolding transpose_folding( 44 [](const HloInstruction& dot, 45 const TransposeFolding::OperandIndices& candidate_operands) { 46 return candidate_operands; 47 }, 48 [](const HloInstruction& convolution, 49 const TransposeFolding::OperandIndices& candidate_operands) { 50 return candidate_operands; 51 }); 52 EXPECT_IS_OK(transpose_folding.Run(module).status()); 53 } 54}; 55 56TEST_F(TransposeFoldingTest, FoldDotTranspose) { 57 auto builder = HloComputation::Builder("entry_computation"); 58 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 59 /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), 60 /*name=*/"x")); 61 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 62 /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), 63 /*name=*/"y")); 64 HloInstruction* transpose_y = 65 builder.AddInstruction(HloInstruction::CreateTranspose( 66 ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); 67 DotDimensionNumbers dot_dnums; 68 dot_dnums.add_lhs_contracting_dimensions(1); 69 dot_dnums.add_rhs_contracting_dimensions(0); 70 HloInstruction* dot = builder.AddInstruction( 71 HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, 72 /*rhs=*/transpose_y, dot_dnums)); 73 74 HloModule module("test_module"); 75 HloComputation* entry_computation = 76 module.AddEntryComputation(builder.Build(dot)); 77 FoldTranspose(&module); 78 79 // Instructions after folding: x, y, and the fusion. 80 std::unordered_set<HloInstruction*> instruction_set( 81 entry_computation->instructions().begin(), 82 entry_computation->instructions().end()); 83 CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; 84 CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; 85 CHECK_EQ(1, instruction_set.size()) 86 << "entry_computation should contain exactly 3 instructions."; 87 HloInstruction* fusion = *instruction_set.begin(); 88 EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); 89 90 // The fusion instruction should contain two parameters, one transpose and 91 // one dot. 92 EXPECT_EQ(4, fusion->fused_instruction_count()); 93} 94 95TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { 96 auto builder = HloComputation::Builder("entry_computation"); 97 // 2x1 98 HloInstruction* const0 = builder.AddInstruction( 99 HloInstruction::CreateConstant(Literal::CreateR2<float>({{1}, {2}}))); 100 // 3x2 101 HloInstruction* const1 = 102 builder.AddInstruction(HloInstruction::CreateConstant( 103 Literal::CreateR2<float>({{1, 2}, {3, 4}, {5, 6}}))); 104 HloInstruction* transpose0 = 105 builder.AddInstruction(HloInstruction::CreateTranspose( 106 ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); 107 HloInstruction* transpose1 = 108 builder.AddInstruction(HloInstruction::CreateTranspose( 109 ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); 110 DotDimensionNumbers dot_dnums; 111 dot_dnums.add_lhs_contracting_dimensions(1); 112 dot_dnums.add_rhs_contracting_dimensions(0); 113 HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( 114 ShapeUtil::MakeShape(F32, {1, 3}), 115 /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); 116 117 HloModule module("test_module"); 118 HloComputation* entry_computation = 119 module.AddEntryComputation(builder.Build(dot)); 120 FoldTranspose(&module); 121 122 for (auto* instruction : entry_computation->instructions()) { 123 if (instruction->opcode() == HloOpcode::kFusion) { 124 CHECK_EQ(2, instruction->operand_count()); 125 EXPECT_EQ(const0, instruction->operand(0)); 126 EXPECT_EQ(const1, instruction->operand(1)); 127 } 128 } 129 130 // The created fusion instruction should contain two parameters, two 131 // transposes (one for each parameter) and one dot. 132 EXPECT_EQ(5, 133 entry_computation->root_instruction()->fused_instruction_count()); 134} 135 136TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { 137 auto builder = HloComputation::Builder("entry"); 138 // (1.0 + 2.0) * (2.0 - 3.0) 139 HloInstruction* const1 = builder.AddInstruction( 140 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 141 HloInstruction* const2 = builder.AddInstruction( 142 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 143 HloInstruction* const3 = builder.AddInstruction( 144 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 145 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( 146 const1->shape(), HloOpcode::kAdd, const1, const2)); 147 HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( 148 const2->shape(), HloOpcode::kSubtract, const2, const3)); 149 HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( 150 add->shape(), HloOpcode::kMultiply, add, sub)); 151 152 HloModule module("fuse_with_constant_operands"); 153 HloComputation* entry_computation = 154 module.AddEntryComputation(builder.Build(mul)); 155 HloInstruction* call = module.OutlineExpressionFromComputation( 156 {add, sub, mul}, "", entry_computation); 157 EXPECT_EQ(call, entry_computation->root_instruction()); 158 HloComputation* callee_computation = call->to_apply(); 159 // The arguments to the call should be const1, const2, and const3. 160 EXPECT_THAT(call->operands(), 161 ::testing::UnorderedElementsAre(const1, const2, const3)); 162 163 // The callee should contain 3 parameters and 3 binary operators. 164 EXPECT_EQ(6, callee_computation->instruction_count()); 165} 166 167TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { 168 auto builder = HloComputation::Builder("entry_computation"); 169 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 170 /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), 171 /*name=*/"x")); 172 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 173 /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), 174 /*name=*/"y")); 175 HloInstruction* transpose_y = 176 builder.AddInstruction(HloInstruction::CreateTranspose( 177 ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); 178 DotDimensionNumbers dot_dnums; 179 dot_dnums.add_lhs_contracting_dimensions(1); 180 dot_dnums.add_rhs_contracting_dimensions(0); 181 HloInstruction* dot = builder.AddInstruction( 182 HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, 183 /*rhs=*/transpose_y, dot_dnums)); 184 185 HloModule module("test_module"); 186 HloComputation* entry_computation = 187 module.AddEntryComputation(builder.Build(dot)); 188 189 HloInstruction* call = module.OutlineExpressionFromComputation( 190 {transpose_y, dot}, "outlined", entry_computation); 191 192 FoldTranspose(&module); 193 194 // Instructions after folding: x, y, and the fusion. 195 std::unordered_set<HloInstruction*> instruction_set( 196 entry_computation->instructions().begin(), 197 entry_computation->instructions().end()); 198 CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; 199 CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; 200 CHECK_EQ(1, instruction_set.erase(call)) 201 << "call is not in entry_computation."; 202 CHECK(instruction_set.empty()) 203 << "entry_computation should contain exactly 3 instructions."; 204 HloInstruction* fusion = 205 call->called_computations().front()->root_instruction(); 206 EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); 207 208 // The fusion instruction should contain two parameters, one transpose and 209 // one dot. 210 EXPECT_EQ(4, fusion->fused_instruction_count()); 211} 212 213// Test that a two dimension swap of the kernel gets folded into convolution. 214TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { 215 auto builder = HloComputation::Builder("entry_computation"); 216 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 217 /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), 218 /*name=*/"x")); 219 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 220 /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), 221 /*name=*/"y")); 222 HloInstruction* transpose_y = 223 builder.AddInstruction(HloInstruction::CreateTranspose( 224 ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3})); 225 auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); 226 Window window; 227 for (int i = 0; i < 2; ++i) { 228 WindowDimension* dim = window.add_dimensions(); 229 dim->set_padding_low(0); 230 dim->set_padding_high(0); 231 dim->set_base_dilation(1); 232 dim->set_window_dilation(1); 233 dim->set_stride(1); 234 dim->set_size( 235 transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); 236 } 237 StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( 238 x->shape(), transpose_y->shape(), window, dnums); 239 EXPECT_IS_OK(conv_shape); 240 HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 241 conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); 242 243 HloModule module("test_module"); 244 HloComputation* entry_computation = 245 module.AddEntryComputation(builder.Build(conv)); 246 FoldTranspose(&module); 247 248 // Instructions after folding: x, y, and the convolution. 249 std::unordered_set<HloInstruction*> instruction_set( 250 entry_computation->instructions().begin(), 251 entry_computation->instructions().end()); 252 CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; 253 CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; 254 CHECK_EQ(1, instruction_set.size()) 255 << "entry_computation should contain exactly 3 instructions."; 256 HloInstruction* new_conv = *instruction_set.begin(); 257 EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); 258 EXPECT_EQ(dnums.kernel_input_feature_dimension(), 259 new_conv->convolution_dimension_numbers() 260 .kernel_output_feature_dimension()); 261 EXPECT_EQ(dnums.kernel_output_feature_dimension(), 262 new_conv->convolution_dimension_numbers() 263 .kernel_input_feature_dimension()); 264} 265 266// Test that a complex transpose of the kernel gets folded into convolution. 267TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { 268 auto builder = HloComputation::Builder("entry_computation"); 269 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 270 /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), 271 /*name=*/"x")); 272 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 273 /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {1, 2, 1, 3}), 274 /*name=*/"y")); 275 HloInstruction* transpose_y = 276 builder.AddInstruction(HloInstruction::CreateTranspose( 277 ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2})); 278 auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); 279 Window window; 280 for (int i = 0; i < 2; ++i) { 281 WindowDimension* dim = window.add_dimensions(); 282 dim->set_padding_low(0); 283 dim->set_padding_high(0); 284 dim->set_base_dilation(1); 285 dim->set_window_dilation(1); 286 dim->set_stride(1); 287 dim->set_size( 288 transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); 289 } 290 StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( 291 x->shape(), transpose_y->shape(), window, dnums); 292 EXPECT_IS_OK(conv_shape); 293 HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 294 conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); 295 296 HloModule module("test_module"); 297 HloComputation* entry_computation = 298 module.AddEntryComputation(builder.Build(conv)); 299 FoldTranspose(&module); 300 301 // Instructions after folding: x, y, and the convolution. 302 std::unordered_set<HloInstruction*> instruction_set( 303 entry_computation->instructions().begin(), 304 entry_computation->instructions().end()); 305 CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; 306 CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; 307 CHECK_EQ(1, instruction_set.size()) 308 << "entry_computation should contain exactly 3 instructions."; 309 HloInstruction* new_conv = *instruction_set.begin(); 310 EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); 311 EXPECT_EQ(dnums.kernel_input_feature_dimension(), 312 new_conv->convolution_dimension_numbers() 313 .kernel_output_feature_dimension()); 314 EXPECT_EQ(dnums.kernel_spatial_dimensions(1), 315 new_conv->convolution_dimension_numbers() 316 .kernel_input_feature_dimension()); 317 EXPECT_EQ( 318 dnums.kernel_output_feature_dimension(), 319 new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(0)); 320 EXPECT_EQ( 321 dnums.kernel_spatial_dimensions(0), 322 new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1)); 323} 324 325// Test that a transpose of the activations gets folded into convolution. 326TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { 327 auto builder = HloComputation::Builder("entry_computation"); 328 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 329 /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), 330 /*name=*/"x")); 331 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 332 /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), 333 /*name=*/"y")); 334 HloInstruction* transpose_x = 335 builder.AddInstruction(HloInstruction::CreateTranspose( 336 ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3})); 337 auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); 338 Window window; 339 for (int i = 0; i < 2; ++i) { 340 WindowDimension* dim = window.add_dimensions(); 341 dim->set_padding_low(0); 342 dim->set_padding_high(0); 343 dim->set_base_dilation(1); 344 dim->set_window_dilation(1); 345 dim->set_stride(1); 346 dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); 347 } 348 StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( 349 transpose_x->shape(), y->shape(), window, dnums); 350 EXPECT_IS_OK(conv_shape); 351 HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 352 conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); 353 354 HloModule module("test_module"); 355 HloComputation* entry_computation = 356 module.AddEntryComputation(builder.Build(conv)); 357 FoldTranspose(&module); 358 359 // Instructions after folding: x, y, and the convolution. 360 std::unordered_set<HloInstruction*> instruction_set( 361 entry_computation->instructions().begin(), 362 entry_computation->instructions().end()); 363 EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; 364 EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; 365 EXPECT_EQ(1, instruction_set.size()) 366 << "entry_computation should contain exactly 3 instructions."; 367 HloInstruction* new_conv = *instruction_set.begin(); 368 EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); 369 EXPECT_EQ(dnums.input_feature_dimension(), 370 new_conv->convolution_dimension_numbers().input_batch_dimension()); 371 EXPECT_EQ( 372 dnums.input_batch_dimension(), 373 new_conv->convolution_dimension_numbers().input_feature_dimension()); 374 EXPECT_EQ( 375 dnums.input_spatial_dimensions(0), 376 new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); 377 EXPECT_EQ( 378 dnums.input_spatial_dimensions(1), 379 new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); 380 EXPECT_EQ( 381 dnums.output_spatial_dimensions(0), 382 new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); 383 EXPECT_EQ( 384 dnums.output_spatial_dimensions(1), 385 new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); 386} 387 388// Test that a transpose of every dimension in the activations gets folded into 389// convolution. 390TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { 391 auto builder = HloComputation::Builder("entry_computation"); 392 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 393 /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), 394 /*name=*/"x")); 395 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 396 /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), 397 /*name=*/"y")); 398 HloInstruction* transpose_x = 399 builder.AddInstruction(HloInstruction::CreateTranspose( 400 ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); 401 auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); 402 Window window; 403 for (int i = 0; i < 2; ++i) { 404 WindowDimension* dim = window.add_dimensions(); 405 dim->set_padding_low(0); 406 dim->set_padding_high(0); 407 dim->set_base_dilation(1); 408 dim->set_window_dilation(1); 409 dim->set_stride(1); 410 dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); 411 } 412 StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( 413 transpose_x->shape(), y->shape(), window, dnums); 414 EXPECT_IS_OK(conv_shape); 415 HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 416 conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); 417 418 HloModule module("test_module"); 419 HloComputation* entry_computation = 420 module.AddEntryComputation(builder.Build(conv)); 421 FoldTranspose(&module); 422 423 // Instructions after folding: x, y, and the convolution. 424 std::unordered_set<HloInstruction*> instruction_set( 425 entry_computation->instructions().begin(), 426 entry_computation->instructions().end()); 427 EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; 428 EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; 429 EXPECT_EQ(1, instruction_set.size()) 430 << "entry_computation should contain exactly 3 instructions."; 431 HloInstruction* new_conv = *instruction_set.begin(); 432 EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); 433 EXPECT_EQ(dnums.input_feature_dimension(), 434 new_conv->convolution_dimension_numbers().input_batch_dimension()); 435 EXPECT_EQ( 436 dnums.input_batch_dimension(), 437 new_conv->convolution_dimension_numbers().input_feature_dimension()); 438 EXPECT_EQ( 439 dnums.input_spatial_dimensions(0), 440 new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); 441 EXPECT_EQ( 442 dnums.input_spatial_dimensions(1), 443 new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); 444 EXPECT_EQ( 445 dnums.output_spatial_dimensions(0), 446 new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); 447 EXPECT_EQ( 448 dnums.output_spatial_dimensions(1), 449 new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); 450} 451 452} // namespace 453} // namespace xla 454