fusion_test.cc revision 3b41352a3177c2fe8a1329e8981b285bb6aacf8b
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <math.h>
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <algorithm>
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <new>
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <utility>
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h"
233b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/client_library.h"
243b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/computation.h"
253b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/computation_builder.h"
269641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h"
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/primitive_util.h"
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/ptr_util.h"
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_computation.h"
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_instruction.h"
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_module.h"
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_opcode.h"
343b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/service/platform_util.h"
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h"
363b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h"
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/test_macros.h"
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/gtl/array_slice.h"
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/protobuf.h"
443b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/core/platform/test_benchmark.h"
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h"
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsusing tensorflow::gtl::ArraySlice;
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
493b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowernamespace se = ::perftools::gputools;
503b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
// Extents of every 2D operand used by the elementwise fusion tests. Operands
// are indexed as (i, j) with i < test_width and j < test_height.
const int test_width = 2;
const int test_height = 3;

// Input values for the elementwise tests, indexed as [operand][i][j]. Three
// operand tables are provided, which is enough for a ternary opcode.
const float test_float_vals[3][test_width][test_height] = {
    {{-1.0f, -1.0f, 1.0f}, {-3.0f, 0.0f, -1.0f}},
    {{-3.0f, 2.0f, 1.0f}, {0.0f, -3.0f, 1.0f}},
    {{-3.0f, 0.0f, -3.0f}, {-1.0f, -2.0f, 1.0f}}};
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Test whether fusion operations are emitted with no errors and compute
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// accurate outputs.
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass FusionTest : public HloTestBase {
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected:
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  template <typename T, int Arity>
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void TestElementwise2D(HloOpcode opcode) {
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    Array2D<float> operand_data[Arity];
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < Arity; ++i) {
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      new (&operand_data[i]) Array2D<float>(test_width, test_height);
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    Array2D<T> answer_data(test_width, test_height);
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < test_width; ++i) {
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int j = 0; j < test_height; ++j) {
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        float xs[Arity];
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int k = 0; k < Arity; ++k) {
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          xs[k] = test_float_vals[k][i][j];
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          operand_data[k](i, j) = xs[k];
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs);
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto builder = HloComputation::Builder(TestName());
849641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky    auto hlo_module = CreateNewModule();
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto prim_type = primitive_util::NativeToPrimitiveType<T>();
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    HloInstruction* hlos[4];
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < Arity; ++i) {
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
9146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower          Literal::CreateR2FromArray2D(operand_data[i])));
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto answer_shape =
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ShapeUtil::MakeShape(prim_type, {test_width, test_height});
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::unique_ptr<HloInstruction> root_hlo;
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    switch (Arity) {
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      case 1:
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        break;
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      case 2:
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                hlos[2]);
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        break;
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      case 3:
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                 hlos[2], hlos[3]);
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        break;
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      default:
1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        LOG(FATAL) << "Bad arity: " << Arity;
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    hlos[0] = builder.AddInstruction(std::move(root_hlo));
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    hlo_module->AddEntryComputation(builder.Build())
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ->CreateFusionInstruction(
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            HloInstruction::FusionKind::kLoop);
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
11746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower    auto expected = Literal::CreateR2FromArray2D(answer_data);
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    if (primitive_util::IsFloatingPointType(prim_type)) {
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    } else {
1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      LiteralTestUtil::ExpectEqual(*expected, *actual);
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private:
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  template <typename T>
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice<float> xs);
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinstemplate <>
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsfloat FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                  ArraySlice<float> xs) {
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  switch (opcode) {
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kAdd:
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] + xs[1];
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kSubtract:
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] - xs[1];
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kMultiply:
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] * xs[1];
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kDivide:
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] / xs[1];
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kPower:
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return powf(xs[0], xs[1]);
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kMinimum:
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return std::min(xs[0], xs[1]);
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kMaximum:
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return std::max(xs[0], xs[1]);
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kClamp:
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return std::min(xs[2], std::max(xs[1], xs[0]));
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    default:
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      LOG(FATAL) << "No elementwise opcode: " << opcode;
1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinstemplate <>
1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsuint8 FusionTest::ComputeElementwiseAnswer<uint8>(HloOpcode opcode,
1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                  ArraySlice<float> xs) {
1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  switch (opcode) {
1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kEq:
1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] == xs[1];
1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kNe:
1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] != xs[1];
1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kGt:
1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] > xs[1];
1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kLt:
1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] < xs[1];
1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kGe:
1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] >= xs[1];
1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kLe:
1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] <= xs[1];
1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    default:
1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      LOG(FATAL) << "No comparatory opcode: " << opcode;
1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Test) {
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // test expression:
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // slice(select({{T, F, T}, {F, T, F}},
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //              concat(transpose({{1.0}, {2.0}, {3.0}} +
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //                               {{-1.0}, {-1.0}, {-1.0}}),
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //                     {{1.62, 2.72, 3.14}}) +
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //                     (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //              {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
1869641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
18846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
19046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
19646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{1.62, 2.72, 3.14}})));
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
20046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
20646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
20746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto const10 = builder.AddInstruction(HloInstruction::CreateConstant(
20846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<bool>({{true, false, true}, {false, true, false}})));
2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto select11 = builder.AddInstruction(
2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                    HloOpcode::kSelect, const10, add8, const9));
2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}));
2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // CreateFusionInstruction needs the `instructions_to_fuse` argument in
2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // reverse topological order, so the first element in `instructions_to_fuse`
2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // must be the root.
2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(
2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          {slice12, select11, const10, const9, add8, negate7, const6, concat5,
2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins           const4, reshape3, add2, const1, const0},
2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          HloInstruction::FusionKind::kLoop);
2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
22346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}),
2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                              *ExecuteAndTransfer(std::move(hlo_module), {}),
2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                              ErrorSpec(1e-4));
2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Test whether we emit appropriate code for parameters of fusion instructions.
2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Parameter) {
2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Build a computation and fuse part of it so the fusion instruction has an
2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // operand parameter.
2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
2339641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
23546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{1.0, 2.0, 3.0}})));
2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
23946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{-2.0, -2.0, -2.0}})));
2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological
2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // order.
2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
24946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                              *ExecuteAndTransfer(std::move(hlo_module), {}),
2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                              ErrorSpec(1e-4));
2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
2569641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
25846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR1<float>({1.0, 2.0, 3.0})));
2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
26046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto broadcast = builder.AddInstruction(
2621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
2631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // add2 = broadcast(const_vector) + const_array
2641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //      = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
2651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //      = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
2661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto add2 = builder.AddInstruction(
2671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}),
2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                   HloOpcode::kAdd, broadcast, const_array));
2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectNear(
27446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4));
2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, ReshapeToScalar) {
2791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
2809641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
2811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto single_element_array = builder.AddInstruction(
28246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR2<int32>({{5}})));
2831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
2841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), single_element_array));
2851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
2861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
2871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
28846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5),
2891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
2901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
2931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
2949641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
2951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
29646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
2971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
2981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
3021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
30346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3099641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
31146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
31846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_1by1by1_) {
3231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3249641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(
32646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR3<int32>({{{7}}})));
3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
33246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape__1by1by1) {
3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3389641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(
34046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
3431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
34646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}),
3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape__) {
3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3529641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(
35446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
3551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
36046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3669641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
36846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
37546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Transpose_2by3) {
3801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3819641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
38346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
3851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
3861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
39046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Transpose_3by3) {
3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3969641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
39846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
40546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
4061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reverse) {
4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
4119641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(
41346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
4141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
4151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
4161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
4171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
4191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
42046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}),
4211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsstd::unique_ptr<HloComputation> MakeReduceTestComputation() {
4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder("add");
4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs"));
4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs"));
4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.AddInstruction(HloInstruction::CreateBinary(
4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs));
4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return builder.Build();
4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
4369641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
43946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto const0 = builder.AddInstruction(
44046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const1 = builder.AddInstruction(
44246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
45046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15),
4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
4559641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
4561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
45846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto const0 = builder.AddInstruction(
45946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const1 = builder.AddInstruction(
46146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
4641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kNegate, reduce2));
4671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
4681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
4691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
4701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
47146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-15}),
4721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
4731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
4761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
4779641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
4781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
47946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
4801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const1 = builder.AddInstruction(
48146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Window window;
4831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ASSERT_TRUE(
4841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "size:2\n"
4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "stride:1\n"
4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "padding_low:0\n"
4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "padding_high:0\n"
4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "window_dilation:1\n"
4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "base_dilation:1\n"
4911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "}\n"
4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "dimensions:{\n"
4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "size:2\n"
4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "stride:1\n"
4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "padding_low:0\n"
4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "padding_high:0\n"
4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "window_dilation:1\n"
4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "base_dilation:1\n"
4991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "}\n",
5001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        &window));
5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto nested_builder = HloComputation::Builder("mul");
5021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  {
5031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto x = nested_builder.AddInstruction(
5041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x"));
5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto y = nested_builder.AddInstruction(
5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y"));
5071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    nested_builder.AddInstruction(HloInstruction::CreateBinary(
5081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y));
5091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto nested_computation =
5111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      hlo_module->AddEmbeddedComputation(nested_builder.Build());
5121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reduce_window2 =
5131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.AddInstruction(HloInstruction::CreateReduceWindow(
5141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window,
5151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          nested_computation));
5161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
5171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
5181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
5191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
52146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
5221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
5231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
5261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Subtract2D) {
5281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kSubtract);
5291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Multiply2D) {
5321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kMultiply);
5331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Divide2D) {
5361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kDivide);
5371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Power2D) {
5401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kPower);
5411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Minimum2D) {
5441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kMinimum);
5451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Maximum2D) {
5481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kMaximum);
5491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<uint8, 2>(HloOpcode::kEq); }
5521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Inequal2D) {
5541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<uint8, 2>(HloOpcode::kNe);
5551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Greater2D) {
5581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<uint8, 2>(HloOpcode::kGt);
5591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Lesser2D) {
5621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<uint8, 2>(HloOpcode::kLt);
5631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, GreaterOrEqual2D) {
5661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<uint8, 2>(HloOpcode::kGe);
5671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, LesserOrEqual2D) {
5701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<uint8, 2>(HloOpcode::kLe);
5711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Clamp2D) {
5741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 3>(HloOpcode::kClamp);
5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5773b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowervoid BM_ParallelFusion(int num_iters) {
5783b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  // Simple element-wise computation to benchmark parallel task partitioning.
5793b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  tensorflow::testing::StopTiming();
5803b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
5813b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
5823b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
5833b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  StreamExecutorMemoryAllocator allocator(platform, executors);
5843b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
5853b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  const int64 intra_op_parallelism_threads = 16;
5863b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  xla::LocalClientOptions client_options;
5873b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  client_options.set_platform(platform);
5883b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads);
5893b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto client =
5903b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower      ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
5913b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
5923b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  const int64 dim_size = 1024;
5933b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  // Create a simple fusable elementwise computation.
5943b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  ComputationBuilder builder(client, "ParallelFusion");
5953b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size});
5963b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto input0 = builder.Broadcast(builder.ConstantR0<float>(1.5f),
5973b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower                                  AsInt64Slice(input_shape.dimensions()));
5983b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto input1 = builder.Broadcast(builder.ConstantR0<float>(2.0f),
5993b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower                                  AsInt64Slice(input_shape.dimensions()));
6003b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto input2 = builder.Broadcast(builder.ConstantR0<float>(3.0f),
6013b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower                                  AsInt64Slice(input_shape.dimensions()));
6023b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto x = builder.Mul(input0, input1);
6033b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto y = builder.Add(x, input2);
6043b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto computation = builder.Build().ConsumeValueOrDie();
6053b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
6063b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  std::unique_ptr<LocalExecutable> executable =
6073b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower      client->Compile(computation, {}, ExecutableBuildOptions())
6083b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower          .ConsumeValueOrDie();
6093b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
6103b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  // Run some warm-up executions.
6113b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  ExecutableRunOptions options;
6123b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  options.set_allocator(&allocator);
6133b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  const int kWarmups = 2;
6143b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  for (int i = 0; i < kWarmups; ++i) {
6153b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower    auto result = executable->Run({}, options);
6163b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower    ASSERT_TRUE(result.ok());
6173b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  }
6183b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
6193b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  // Run benchmark.
6203b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) * dim_size *
6213b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower                                      dim_size * sizeof(float));
6223b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  tensorflow::testing::UseRealTime();
6233b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  tensorflow::testing::StartTiming();
6243b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  for (int i = 0; i < num_iters; ++i) {
6253b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower    auto result = executable->Run({}, options);
6263b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower    ASSERT_TRUE(result.ok());
6273b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  }
6283b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower}
6293b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
6303b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowerBENCHMARK(BM_ParallelFusion);
6313b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
6321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
6331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsint main(int argc, char** argv) {
6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<tensorflow::Flag> flag_list;
6379641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (!parse_result) {
6411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    LOG(ERROR) << "\n" << usage;
6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return 2;
6431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  testing::InitGoogleTest(&argc, argv);
6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (argc > 1) {
6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return 2;
6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6493b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  tensorflow::testing::RunBenchmarks();
6501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return RUN_ALL_TESTS();
6511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
652