// instruction_fusion_test.cc revision 5149785eb7175a791acbd9859872e07439b968b6
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/instruction_fusion.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsusing InstructionFusionTest = HloTestBase;
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest,
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins       CostlyProducerAndOperandElementReusingConsumerNotFused) {
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* const0 = builder.AddInstruction(
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* broadcast2 =
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.AddInstruction(HloInstruction::CreateBroadcast(
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ShapeUtil::MakeShape(S32, {1}), exp1, {0}));
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(broadcast2, computation->root_instruction());
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_TRUE(
395149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
405149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
415149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(broadcast2, computation->root_instruction());
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest,
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins       NonCostlyProducerAndOperandElementReusingConsumerFused) {
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* const0 = builder.AddInstruction(
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* broadcast2 =
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.AddInstruction(HloInstruction::CreateBroadcast(
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ShapeUtil::MakeShape(S32, {1}), negate1, {0}));
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(broadcast2, computation->root_instruction());
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_TRUE(
605149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
615149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
625149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest,
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins       CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* const0 = builder.AddInstruction(
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* reshape2 = builder.AddInstruction(
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1));
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(reshape2, computation->root_instruction());
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_TRUE(
805149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
815149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
825149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest,
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins       CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* const0 = builder.AddInstruction(
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* transpose2 = builder.AddInstruction(
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {}));
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(transpose2, computation->root_instruction());
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_TRUE(
1005149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1015149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
1025149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto param0 = builder.AddInstruction(
1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(reshape1, computation->root_instruction());
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(
1175149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1185149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
1195149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto param0 = builder.AddInstruction(
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(reshape1, computation->root_instruction());
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(
1335149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1345149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
1355149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto param0 = builder.AddInstruction(
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), param0, {}));
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(transpose1, computation->root_instruction());
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(
1495149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1505149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
1515149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
155