instruction_fusion_test.cc revision 5bc685d7f16b0fc27b936e63fa01668e4af4034c
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/instruction_fusion.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
18efce4a0452f1eba55fd35a22d20249f449f0debeA. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_matchers.h"
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
21efce4a0452f1eba55fd35a22d20249f449f0debeA. Unique TensorFlowernamespace op = xla::testing::opcode_matchers;
22efce4a0452f1eba55fd35a22d20249f449f0debeA. Unique TensorFlower
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsusing InstructionFusionTest = HloTestBase;
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest,
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins       CostlyProducerAndOperandElementReusingConsumerNotFused) {
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* const0 = builder.AddInstruction(
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* broadcast2 =
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.AddInstruction(HloInstruction::CreateBroadcast(
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ShapeUtil::MakeShape(S32, {1}), exp1, {0}));
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(broadcast2, computation->root_instruction());
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_TRUE(
425149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
435149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
445149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(broadcast2, computation->root_instruction());
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest,
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins       NonCostlyProducerAndOperandElementReusingConsumerFused) {
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* const0 = builder.AddInstruction(
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* broadcast2 =
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.AddInstruction(HloInstruction::CreateBroadcast(
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ShapeUtil::MakeShape(S32, {1}), negate1, {0}));
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(broadcast2, computation->root_instruction());
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_TRUE(
635149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
645149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
655149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
66efce4a0452f1eba55fd35a22d20249f449f0debeA. Unique TensorFlower  EXPECT_THAT(computation->root_instruction(), op::Fusion());
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest,
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins       CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* const0 = builder.AddInstruction(
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* reshape2 = builder.AddInstruction(
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1));
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(reshape2, computation->root_instruction());
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_TRUE(
835149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
845149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
855149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
86efce4a0452f1eba55fd35a22d20249f449f0debeA. Unique TensorFlower  EXPECT_THAT(computation->root_instruction(), op::Fusion());
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest,
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins       CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* const0 = builder.AddInstruction(
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloInstruction* transpose2 = builder.AddInstruction(
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {}));
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(transpose2, computation->root_instruction());
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_TRUE(
1035149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1045149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
1055149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
106efce4a0452f1eba55fd35a22d20249f449f0debeA. Unique TensorFlower  EXPECT_THAT(computation->root_instruction(), op::Fusion());
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto param0 = builder.AddInstruction(
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(reshape1, computation->root_instruction());
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(
1205149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1215149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
1225149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto param0 = builder.AddInstruction(
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(reshape1, computation->root_instruction());
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(
1365149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1375149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
1385149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  HloComputation::Builder builder(TestName());
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto param0 = builder.AddInstruction(
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), param0, {}));
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto module = MakeUnique<HloModule>(TestName());
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto computation = module->AddEntryComputation(builder.Build());
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_EQ(transpose1, computation->root_instruction());
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(
1525149785eb7175a791acbd9859872e07439b968b6David Majnemer      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1535149785eb7175a791acbd9859872e07439b968b6David Majnemer          .Run(module.get())
1545149785eb7175a791acbd9859872e07439b968b6David Majnemer          .ValueOrDie());
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
157116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlowerTEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
158116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower  HloComputation::Builder builder(TestName());
1595bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
1605bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param0 =
1615bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
1625bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param1 =
1635bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
1645bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* binary1 = builder.AddInstruction(
1655bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
1665bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
1675bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* unary = builder.AddInstruction(
1685bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
1695bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower
1705bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto module = MakeUnique<HloModule>(TestName());
1715bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto computation = module->AddEntryComputation(builder.Build());
1725bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_EQ(unary, computation->root_instruction());
1735bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_FALSE(
1745bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1755bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower          .Run(module.get())
1765bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower          .ValueOrDie());
1775bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower}
1785bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower
1795bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlowerTEST_F(InstructionFusionTest, AllowUnaryDuplication) {
1805bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloComputation::Builder builder(TestName());
1815bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
1825bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param0 =
1835bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
1845bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* unary1 = builder.AddInstruction(
1855bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
186116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower  builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
1875bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* unary2 = builder.AddInstruction(
1885bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
189116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower
190116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower  auto module = MakeUnique<HloModule>(TestName());
191116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower  auto computation = module->AddEntryComputation(builder.Build());
192116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower  EXPECT_EQ(unary2, computation->root_instruction());
1935bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_TRUE(
1945bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
1955bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower          .Run(module.get())
1965bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower          .ValueOrDie());
1975bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower}
1985bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower
1995bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlowerTEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
2005bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
2015bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto small_shape = ShapeUtil::MakeShape(F32, {16});
2025bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloComputation::Builder builder(TestName());
2035bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param0 = builder.AddInstruction(
2045bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateParameter(0, small_shape, "0"));
2055bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto param1 =
2065bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
2075bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* binary1 = builder.AddInstruction(
2085bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
2095bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
2105bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  HloInstruction* unary = builder.AddInstruction(
2115bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
2125bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower
2135bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto module = MakeUnique<HloModule>(TestName());
2145bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  auto computation = module->AddEntryComputation(builder.Build());
2155bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_EQ(unary, computation->root_instruction());
2165bc685d7f16b0fc27b936e63fa01668e4af4034cA. Unique TensorFlower  EXPECT_TRUE(
217116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
218116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower          .Run(module.get())
219116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower          .ValueOrDie());
220116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower}
221116e1dde1a896e986525d9feccfe88265e17f962A. Unique TensorFlower
2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
223