11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/instruction_fusion.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
18efce4a0452f1eba55fd35a22d20249f449f0debeA. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_matchers.h"
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsusing InstructionFusionTest = HloTestBase;
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto param0 = builder.AddInstruction(
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
329641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto module = CreateNewModule();
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(reshape1, computation->root_instruction());
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(
365149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
375149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
385149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto param0 = builder.AddInstruction(
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
489641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto module = CreateNewModule();
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(reshape1, computation->root_instruction());
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(
525149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
535149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
545149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto param0 = builder.AddInstruction(
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), param0, {}));
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
649641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto module = CreateNewModule();
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(transpose1, computation->root_instruction());
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(
685149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
695149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
705149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
73116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlowerTEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
74116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower  HloComputation::Builder builder(TestName());
755bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
765bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param0 =
775bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
785bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param1 =
795bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
805bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* binary1 = builder.AddInstruction(
815bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
825bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
835bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* unary = builder.AddInstruction(
845bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
855bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower
869641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto module = CreateNewModule();
875bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto computation = module->AddEntryComputation(builder.Build());
885bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_EQ(unary, computation->root_instruction());
895bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_FALSE(
905bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
915bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower          .Run(module.get())
925bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower          .ValueOrDie());
935bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower}
945bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower
955bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlowerTEST_F(InstructionFusionTest, AllowUnaryDuplication) {
965bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloComputation::Builder builder(TestName());
975bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
985bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param0 =
995bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
1005bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* unary1 = builder.AddInstruction(
1015bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
102116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower  builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
1035bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* unary2 = builder.AddInstruction(
1045bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
105116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower
1069641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto module = CreateNewModule();
107116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower  auto computation = module->AddEntryComputation(builder.Build());
108116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower  EXPECT_EQ(unary2, computation->root_instruction());
1095bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_TRUE(
1105bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1115bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower          .Run(module.get())
1125bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower          .ValueOrDie());
1135bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower}
1145bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower
1155bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlowerTEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
1165bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
1175bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto small_shape = ShapeUtil::MakeShape(F32, {16});
1185bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloComputation::Builder builder(TestName());
1195bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param0 = builder.AddInstruction(
1205bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateParameter(0, small_shape, "0"));
1215bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param1 =
1225bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
1235bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* binary1 = builder.AddInstruction(
1245bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
1255bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
1265bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* unary = builder.AddInstruction(
1275bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
1285bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower
1299641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto module = CreateNewModule();
1305bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto computation = module->AddEntryComputation(builder.Build());
1315bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_EQ(unary, computation->root_instruction());
1325bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_TRUE(
133116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
134116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower          .Run(module.get())
135116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower          .ValueOrDie());
136116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower}
137116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
139