11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <math.h>
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <algorithm>
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <new>
20a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower#include <random>
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <utility>
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
23a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower#define EIGEN_USE_THREADS
24a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
25a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h"
273b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/client_library.h"
283b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/computation.h"
293b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/computation_builder.h"
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h"
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/primitive_util.h"
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/ptr_util.h"
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_computation.h"
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_instruction.h"
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_module.h"
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_opcode.h"
373b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/service/platform_util.h"
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h"
393b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h"
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/test_macros.h"
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
44a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/gtl/array_slice.h"
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/protobuf.h"
483b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/core/platform/test_benchmark.h"
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h"
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsusing tensorflow::gtl::ArraySlice;
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
533b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowernamespace se = ::perftools::gputools;
543b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
// Every 2D operand and result built by TestElementwise2D is
// test_width x test_height.
constexpr int test_width = 2;
constexpr int test_height = 3;

// Operand values for TestElementwise2D. test_float_vals[k] is the k-th
// operand, indexed as [i][j] with 0 <= i < test_width and 0 <= j < test_height.
// Three planes are provided, so at most three operands are supported.
constexpr float test_float_vals[3][test_width][test_height] = {
    {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}},
    {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}},
    {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}},
};
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Test whether fusion operations are emitted with no errors and compute
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// accurate outputs.
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass FusionTest : public HloTestBase {
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected:
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  template <typename T, int Arity>
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void TestElementwise2D(HloOpcode opcode) {
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    Array2D<float> operand_data[Arity];
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < Arity; ++i) {
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      new (&operand_data[i]) Array2D<float>(test_width, test_height);
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    Array2D<T> answer_data(test_width, test_height);
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < test_width; ++i) {
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      for (int j = 0; j < test_height; ++j) {
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        float xs[Arity];
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        for (int k = 0; k < Arity; ++k) {
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          xs[k] = test_float_vals[k][i][j];
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          operand_data[k](i, j) = xs[k];
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        }
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs);
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      }
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto builder = HloComputation::Builder(TestName());
889641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky    auto hlo_module = CreateNewModule();
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto prim_type = primitive_util::NativeToPrimitiveType<T>();
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    HloInstruction* hlos[4];
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < Arity; ++i) {
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
9546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower          Literal::CreateR2FromArray2D(operand_data[i])));
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto answer_shape =
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ShapeUtil::MakeShape(prim_type, {test_width, test_height});
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::unique_ptr<HloInstruction> root_hlo;
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    switch (Arity) {
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      case 1:
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        break;
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      case 2:
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                hlos[2]);
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        break;
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      case 3:
1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                 hlos[2], hlos[3]);
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        break;
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      default:
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        LOG(FATAL) << "Bad arity: " << Arity;
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    hlos[0] = builder.AddInstruction(std::move(root_hlo));
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    hlo_module->AddEntryComputation(builder.Build())
1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ->CreateFusionInstruction(
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            HloInstruction::FusionKind::kLoop);
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
12146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower    auto expected = Literal::CreateR2FromArray2D(answer_data);
1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    if (primitive_util::IsFloatingPointType(prim_type)) {
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    } else {
1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      LiteralTestUtil::ExpectEqual(*expected, *actual);
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private:
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  template <typename T>
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice<float> xs);
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinstemplate <>
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsfloat FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                  ArraySlice<float> xs) {
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  switch (opcode) {
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kAdd:
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] + xs[1];
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kSubtract:
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] - xs[1];
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kMultiply:
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] * xs[1];
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kDivide:
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] / xs[1];
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kPower:
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return powf(xs[0], xs[1]);
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kMinimum:
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return std::min(xs[0], xs[1]);
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kMaximum:
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return std::max(xs[0], xs[1]);
1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kClamp:
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return std::min(xs[2], std::max(xs[1], xs[0]));
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    default:
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      LOG(FATAL) << "No elementwise opcode: " << opcode;
1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinstemplate <>
16111698cc8e157eefe71a60931f1e721ad327e08afMark Heffernanbool FusionTest::ComputeElementwiseAnswer<bool>(HloOpcode opcode,
16211698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan                                                ArraySlice<float> xs) {
1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  switch (opcode) {
1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kEq:
1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] == xs[1];
1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kNe:
1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] != xs[1];
1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kGt:
1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] > xs[1];
1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kLt:
1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] < xs[1];
1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kGe:
1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] >= xs[1];
1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    case HloOpcode::kLe:
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      return xs[0] <= xs[1];
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    default:
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      LOG(FATAL) << "No comparatory opcode: " << opcode;
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Test) {
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // test expression:
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // slice(select({{T, F, T}, {F, T, F}},
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //              concat(transpose({{1.0}, {2.0}, {3.0}} +
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //                               {{-1.0}, {-1.0}, {-1.0}}),
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //                     {{1.62, 2.72, 3.14}}) +
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //                     (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //              {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
1909641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
19246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
19446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
20046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{1.62, 2.72, 3.14}})));
2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
20446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
21046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
21146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto const10 = builder.AddInstruction(HloInstruction::CreateConstant(
21246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<bool>({{true, false, true}, {false, true, false}})));
2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto select11 = builder.AddInstruction(
2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                    HloOpcode::kSelect, const10, add8, const9));
2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
21750b999a8336d19400ab75aea66fe46eca2f5fe0bA. Unique TensorFlower      ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1}));
2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // CreateFusionInstruction needs the `instructions_to_fuse` argument in
2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // reverse topological order, so the first element in `instructions_to_fuse`
2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // must be the root.
2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(
2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          {slice12, select11, const10, const9, add8, negate7, const6, concat5,
2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins           const4, reshape3, add2, const1, const0},
2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          HloInstruction::FusionKind::kLoop);
2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
22746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}),
2281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                              *ExecuteAndTransfer(std::move(hlo_module), {}),
2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                              ErrorSpec(1e-4));
2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Test whether we emit appropriate code for parameters of fusion instructions.
2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Parameter) {
2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Build a computation and fuse part of it so the fusion instruction has an
2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // operand parameter.
2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
2379641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
23946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{1.0, 2.0, 3.0}})));
2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
24346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{-2.0, -2.0, -2.0}})));
2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological
2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // order.
2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
25346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                              *ExecuteAndTransfer(std::move(hlo_module), {}),
2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                              ErrorSpec(1e-4));
2561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
258a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerXLA_TEST_F(FusionTest, RandomizedParallelPartition) {
259a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Tests parallel partitioning of a fusion instruction.
260a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Create shape with random outer dimension size to generate random parallel
261a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // partition counts for each test run.
262a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int seed = tensorflow::testing::RandomSeed();
263a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  LOG(INFO) << "RandomizedParallelPartition seed: " << seed;
264a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  std::mt19937 generator(seed);
265a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  std::uniform_int_distribution<int> distribution(128, 1024);
266a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 rand_dim0_size = distribution(generator);
267a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 dim1_size = 1024;
268a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  Shape shape =
269a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0});
270a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Build simple fusion computation: y = x^2 (elementwise).
271a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto builder = HloComputation::Builder(TestName());
272a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto hlo_module = CreateNewModule();
273a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
274a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto two = builder.AddInstruction(
275a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
276a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto x =
277a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
278a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto y = builder.AddInstruction(
279a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x));
280a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
281a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  hlo_module->AddEntryComputation(builder.Build())
282a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two},
283a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                HloInstruction::FusionKind::kLoop);
284a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Compute result.
285a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
286a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Every element of result should be y = x^2 = 4.0.
287a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  for (int i = 0; i < rand_dim0_size; ++i) {
288a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    for (int j = 0; j < dim1_size; ++j) {
289a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      EXPECT_EQ(4.0, result->Get<float>({i, j}));
290a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
291a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  }
292a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower}
293a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
2941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
2951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
2969641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
2971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
29846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR1<float>({1.0, 2.0, 3.0})));
2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
30046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto broadcast = builder.AddInstruction(
3021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
3031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // add2 = broadcast(const_vector) + const_array
3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //      = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //      = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto add2 = builder.AddInstruction(
3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}),
3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                   HloOpcode::kAdd, broadcast, const_array));
3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
3111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectNear(
31446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4));
3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, ReshapeToScalar) {
3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3209641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto single_element_array = builder.AddInstruction(
32246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR2<int32>({{5}})));
3231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
3241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), single_element_array));
3251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
32846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5),
3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3349641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
33646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
34346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3499641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
35146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
3541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
35846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_1by1by1_) {
3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3649641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(
36646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR3<int32>({{{7}}})));
3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
37246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape__1by1by1) {
3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3789641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(
38046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
3811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
3831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
38646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}),
3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape__) {
3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
3929641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
3931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(
39446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
40046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
4069641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
40846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(
4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
4111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
4131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
4141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
41546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
4161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
4171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Transpose_2by3) {
4201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
4219641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
42346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
43046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Transpose_3by3) {
4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
4369641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
43846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
44546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reverse) {
4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
4519641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(
45346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
4551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
4561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
46046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}),
4611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
46446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, ReverseNegate) {
46546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto builder = HloComputation::Builder(TestName());
46646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto hlo_module = CreateNewModule();
46746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto const0 = builder.AddInstruction(
46846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
46946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
47046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
47146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
47246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {3}), HloOpcode::kNegate, reverse1));
47346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  hlo_module->AddEntryComputation(builder.Build())
47446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
47546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                                HloInstruction::FusionKind::kLoop);
47646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
47746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-3, -2, -1}),
47846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                               *ExecuteAndTransfer(std::move(hlo_module), {}));
47946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower}
48046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
48146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, BroadcastNegate) {
48246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto builder = HloComputation::Builder(TestName());
48346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto hlo_module = CreateNewModule();
48446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto const0 = builder.AddInstruction(
48546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
48646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
48746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {2}), const0, {}));
48846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
48946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, broadcast1));
49046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  hlo_module->AddEntryComputation(builder.Build())
49146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
49246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                                HloInstruction::FusionKind::kLoop);
49346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
49446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -1}),
49546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                               *ExecuteAndTransfer(std::move(hlo_module), {}));
49646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower}
49746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
49846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, SliceNegate) {
49946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto builder = HloComputation::Builder(TestName());
50046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto hlo_module = CreateNewModule();
50146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto const0 = builder.AddInstruction(
50246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
50346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
50446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2}));
50546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
50646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1));
50746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  hlo_module->AddEntryComputation(builder.Build())
50846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
50946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                                HloInstruction::FusionKind::kLoop);
51046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
51146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -3}),
51246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                               *ExecuteAndTransfer(std::move(hlo_module), {}));
51346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower}
51446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
51546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, DynamicSliceNegate) {
51646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto builder = HloComputation::Builder(TestName());
51746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto hlo_module = CreateNewModule();
51846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto const0 = builder.AddInstruction(
51946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
52046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto const1 = builder.AddInstruction(
52146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1})));
52246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto dynamic_slice2 =
52346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
52446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower          ShapeUtil::MakeShape(S32, {2}), const0, const1, {2}));
52546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
52646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2));
52746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  hlo_module->AddEntryComputation(builder.Build())
52846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ->CreateFusionInstruction(
52946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower          /*instructions_to_fuse=*/{negate3, dynamic_slice2},
53046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower          HloInstruction::FusionKind::kLoop);
53146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
53246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-2, -3}),
53346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                               *ExecuteAndTransfer(std::move(hlo_module), {}));
53446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower}
53546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
53646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, ReshapeNegate) {
53746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto builder = HloComputation::Builder(TestName());
53846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto hlo_module = CreateNewModule();
53946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto const0 = builder.AddInstruction(
54046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
54146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto reshape1 = builder.AddInstruction(
54246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0));
54346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
54446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, reshape1));
54546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  hlo_module->AddEntryComputation(builder.Build())
54646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
54746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                                HloInstruction::FusionKind::kLoop);
54846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
54946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
55046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                               *ExecuteAndTransfer(std::move(hlo_module), {}));
55146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower}
55246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
55346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower// TODO(b/64070202): Investigate failure.
55446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) {
55546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto builder = HloComputation::Builder(TestName());
55646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto hlo_module = CreateNewModule();
55746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
55846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      Literal::CreateR2<int32>({{1, 2}, {3, 4}})));
55946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
56046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0}));
56146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
56246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, transpose1));
56346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  hlo_module->AddEntryComputation(builder.Build())
56446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
56546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                                HloInstruction::FusionKind::kLoop);
56646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
56746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
56846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower                               *ExecuteAndTransfer(std::move(hlo_module), {}));
56946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower}
57046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower
5711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsstd::unique_ptr<HloComputation> MakeReduceTestComputation() {
5721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder("add");
5731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
5741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs"));
5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs"));
5771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.AddInstruction(HloInstruction::CreateBinary(
5781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs));
5791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return builder.Build();
5801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
5839641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
5841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
58646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto const0 = builder.AddInstruction(
58746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
5881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const1 = builder.AddInstruction(
58946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
5901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
5911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
59746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15),
5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
5991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
6029641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
60546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto const0 = builder.AddInstruction(
60646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
6071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const1 = builder.AddInstruction(
60846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
6091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
6101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
6111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
6121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
61311698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan      ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, reduce2));
6141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
6151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
6161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
6171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
61811698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(-15),
6191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               *ExecuteAndTransfer(std::move(hlo_module), {}));
6201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
6231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto builder = HloComputation::Builder(TestName());
6249641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky  auto hlo_module = CreateNewModule();
6251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
62646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
6271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto const1 = builder.AddInstruction(
62846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
6291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Window window;
6301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ASSERT_TRUE(
6311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
6321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "size:2\n"
6331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "stride:1\n"
6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "padding_low:0\n"
6351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "padding_high:0\n"
6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "window_dilation:1\n"
6371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "base_dilation:1\n"
6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "}\n"
6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "dimensions:{\n"
6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "size:2\n"
6411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "stride:1\n"
6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "padding_low:0\n"
6431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "padding_high:0\n"
6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "window_dilation:1\n"
6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "base_dilation:1\n"
6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        "}\n",
6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                                        &window));
6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto nested_builder = HloComputation::Builder("mul");
6491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  {
6501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto x = nested_builder.AddInstruction(
6511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x"));
6521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto y = nested_builder.AddInstruction(
6531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y"));
6541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    nested_builder.AddInstruction(HloInstruction::CreateBinary(
6551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y));
6561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
6571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto nested_computation =
6581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      hlo_module->AddEmbeddedComputation(nested_builder.Build());
6591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto reduce_window2 =
6601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.AddInstruction(HloInstruction::CreateReduceWindow(
6611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window,
6621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          nested_computation));
6631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  hlo_module->AddEntryComputation(builder.Build())
6641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
6651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                HloInstruction::FusionKind::kLoop);
6661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  LiteralTestUtil::ExpectEqual(
66846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
6691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      *ExecuteAndTransfer(std::move(hlo_module), {}));
6701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
672a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan// When a constant (or other op) which has multiple users is imported
673a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan// into a fusion, it should remain shared, rather than being duplicated
674a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan// within the fusion.
675a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay VasudevanXLA_TEST_F(FusionTest, SharedConstant) {
676a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  auto hlo_module = CreateNewModule();
677a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan
678a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  auto builder = HloComputation::Builder(TestName());
679a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  auto const0 = builder.AddInstruction(
680a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan          HloInstruction::CreateConstant(Literal::CreateR1<int32>({0})));
681a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  auto const1 = builder.AddInstruction(
682a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan          HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
683a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
684a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan          ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
685a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
686a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan          ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1));
687a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
688a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan          ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2));
689a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
690a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan          ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3));
691a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  hlo_module->AddEntryComputation(builder.Build())
692a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan      ->CreateFusionInstruction(
693a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan        {add4, add3, add2, add1, const1},
694a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan        HloInstruction::FusionKind::kLoop);
695a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan
696a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  HloComputation* entry_comp = hlo_module->entry_computation();
697a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan
698a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  // entry computation contains the constant(0) and the fusion
6999b1b5d85b9ce3c812dc772da1f3f5d09581e5b49Justin Lebar  EXPECT_EQ(entry_comp->instruction_count(), 2);
700a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan
701a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  // fused instruction contains the constant(2), the parameter, and 4 adds
7029b1b5d85b9ce3c812dc772da1f3f5d09581e5b49Justin Lebar  EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
703a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan
704a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}),
705a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan          *ExecuteAndTransfer(std::move(hlo_module), {}));
706a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan}
707a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan
7081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
7091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Subtract2D) {
7111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kSubtract);
7121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Multiply2D) {
7151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kMultiply);
7161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Divide2D) {
7191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kDivide);
7201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Power2D) {
7231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kPower);
7241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Minimum2D) {
7271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kMinimum);
7281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Maximum2D) {
7311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 2>(HloOpcode::kMaximum);
7321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
73411698cc8e157eefe71a60931f1e721ad327e08afMark HeffernanXLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<bool, 2>(HloOpcode::kEq); }
7351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Inequal2D) {
73711698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan  TestElementwise2D<bool, 2>(HloOpcode::kNe);
7381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Greater2D) {
74111698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan  TestElementwise2D<bool, 2>(HloOpcode::kGt);
7421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
74411698cc8e157eefe71a60931f1e721ad327e08afMark HeffernanXLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D<bool, 2>(HloOpcode::kLt); }
7451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, GreaterOrEqual2D) {
74711698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan  TestElementwise2D<bool, 2>(HloOpcode::kGe);
7481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, LesserOrEqual2D) {
75111698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan  TestElementwise2D<bool, 2>(HloOpcode::kLe);
7521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Clamp2D) {
7551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TestElementwise2D<float, 3>(HloOpcode::kClamp);
7561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
7571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
7583b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowervoid BM_ParallelFusion(int num_iters) {
7593b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  // Simple element-wise computation to benchmark parallel task partitioning.
7603b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  tensorflow::testing::StopTiming();
7613b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
7623b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
7633b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
7643b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  StreamExecutorMemoryAllocator allocator(platform, executors);
7653b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
766a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 intra_op_parallelism_threads = 24;
7673b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  xla::LocalClientOptions client_options;
7683b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  client_options.set_platform(platform);
7693b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads);
7703b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto client =
7713b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower      ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
7723b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
773a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  int device_ordinal = client->default_device_ordinal();
774a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
775a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Computation shape parameters.
776a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 param0_dim0 = 1024;
777a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 param0_dim1 = 1024;
778a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 param1_dim0 = 1024;
779a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 param1_dim1 = 1024;
780a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 param2_dim0 = 1024;
781a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 param2_dim1 = 1024;
782a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
783a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Create computation.
7843b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  ComputationBuilder builder(client, "ParallelFusion");
785a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1});
786a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto param0 = builder.Parameter(0, shape0, "param0");
787a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1});
788a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto param1 = builder.Parameter(1, shape1, "param1");
789a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1});
790a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto param2 = builder.Parameter(2, shape2, "param2");
791a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
792a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto x = builder.Mul(param0, param1);
793a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto y = builder.Add(x, param2);
7943b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  auto computation = builder.Build().ConsumeValueOrDie();
7953b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
796a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Transfer literals to device.
797a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto param0_literal =
798a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
79922d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan  std::unique_ptr<ShapedBuffer> buffer0 =
80022d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan      client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
801a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          .ConsumeValueOrDie();
80222d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan
803a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto param1_literal =
804a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
80522d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan  std::unique_ptr<ShapedBuffer> buffer1 =
80622d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan      client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
807a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          .ConsumeValueOrDie();
80822d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan
809a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  auto param2_literal =
810a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
81122d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan  std::unique_ptr<ShapedBuffer> buffer2 =
81222d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan      client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
81322d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan          .ConsumeValueOrDie();
814a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
815a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Build executable.
8163b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  std::unique_ptr<LocalExecutable> executable =
817a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      client
818a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          ->Compile(computation,
819fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower                    {&buffer0->on_host_shape(), &buffer1->on_host_shape(),
820fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower                     &buffer2->on_host_shape()},
821a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                    ExecutableBuildOptions())
8223b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower          .ConsumeValueOrDie();
8233b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
82422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan  se::Stream stream(executors[device_ordinal]);
825a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  stream.Init();
826a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
827a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Initialize thread pool.
828a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
829a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                      intra_op_parallelism_threads);
830a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  tensorflow::EigenThreadPoolWrapper tp(&pool);
831a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
832a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
833a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Initialize ExecutableRunOptions.
8341f1b2bb6c3833a472036da22b7c910f5f2bdf694Anna R  ExecutableRunOptions options;
835a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  options.set_allocator(&allocator).set_stream(&stream);
836a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  options.set_intra_op_thread_pool(&device);
837a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
838a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Run some warm-up executions.
8393b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  const int kWarmups = 2;
8403b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  for (int i = 0; i < kWarmups; ++i) {
841a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto result =
842a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options);
8433b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower    ASSERT_TRUE(result.ok());
8443b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  }
8453b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
8463b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  // Run benchmark.
847a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const int64 total_bytes = param0_dim0 * param0_dim0 +
848a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                            param1_dim0 * param1_dim0 +
849a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                            param2_dim0 * param2_dim0;
850a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
851a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                      total_bytes * sizeof(float));
8523b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  tensorflow::testing::UseRealTime();
8533b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  tensorflow::testing::StartTiming();
8543b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  for (int i = 0; i < num_iters; ++i) {
855a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto result =
856a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options);
8573b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower    ASSERT_TRUE(result.ok());
8583b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower  }
8593b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower}
8603b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
8613b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowerBENCHMARK(BM_ParallelFusion);
8623b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower
8631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
8641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
865