fusion_test.cc revision 3b41352a3177c2fe8a1329e8981b285bb6aacf8b
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <math.h> 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <algorithm> 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory> 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <new> 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <utility> 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h" 233b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/client_library.h" 243b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/computation.h" 253b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/computation_builder.h" 269641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h" 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/primitive_util.h" 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/ptr_util.h" 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_computation.h" 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_instruction.h" 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_module.h" 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_opcode.h" 343b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/service/platform_util.h" 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h" 363b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h" 391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/test_macros.h" 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h" 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/gtl/array_slice.h" 421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/protobuf.h" 443b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/core/platform/test_benchmark.h" 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h" 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsusing tensorflow::gtl::ArraySlice; 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 493b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowernamespace se = ::perftools::gputools; 503b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsconst int test_width = 2, test_height = 3; 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsconst float test_float_vals[3][test_width][test_height] = { 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}}, 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}}, 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}}}; 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Test whether fusion operations are emitted with no errors and compute 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// accurate outputs. 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass FusionTest : public HloTestBase { 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected: 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins template <typename T, int Arity> 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins void TestElementwise2D(HloOpcode opcode) { 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> operand_data[Arity]; 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < Arity; ++i) { 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins new (&operand_data[i]) Array2D<float>(test_width, test_height); 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<T> answer_data(test_width, test_height); 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < test_width; ++i) { 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int j = 0; j < test_height; ++j) { 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float xs[Arity]; 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int k = 0; k < Arity; ++k) { 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xs[k] = test_float_vals[k][i][j]; 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand_data[k](i, j) = xs[k]; 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs); 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 849641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto prim_type = primitive_util::NativeToPrimitiveType<T>(); 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction* hlos[4]; 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < Arity; ++i) { 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant( 9146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2FromArray2D(operand_data[i]))); 921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto answer_shape = 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(prim_type, {test_width, test_height}); 951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::unique_ptr<HloInstruction> root_hlo; 961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins switch (Arity) { 971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case 1: 981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]); 991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins break; 1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case 2: 1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1], 1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlos[2]); 1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins break; 1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case 3: 1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1], 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlos[2], hlos[3]); 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins break; 1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins default: 1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LOG(FATAL) << "Bad arity: " << Arity; 1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlos[0] = builder.AddInstruction(std::move(root_hlo)); 1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction( 1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ArraySlice<HloInstruction*>(hlos, 0, Arity + 1), 1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 11746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = Literal::CreateR2FromArray2D(answer_data); 1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); 1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (primitive_util::IsFloatingPointType(prim_type)) { 1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); 1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } else { 1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual(*expected, *actual); 1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private: 1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins template <typename T> 1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice<float> xs); 1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinstemplate <> 1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsfloat FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode, 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ArraySlice<float> xs) { 1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins switch (opcode) { 1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kAdd: 1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] + xs[1]; 1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kSubtract: 1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] - xs[1]; 1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kMultiply: 1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] * xs[1]; 1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kDivide: 1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] / xs[1]; 1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kPower: 1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return powf(xs[0], xs[1]); 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kMinimum: 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return std::min(xs[0], xs[1]); 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kMaximum: 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return std::max(xs[0], xs[1]); 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kClamp: 1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return std::min(xs[2], std::max(xs[1], xs[0])); 1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins default: 1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LOG(FATAL) << "No elementwise opcode: " << opcode; 1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinstemplate <> 1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsuint8 FusionTest::ComputeElementwiseAnswer<uint8>(HloOpcode opcode, 1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ArraySlice<float> xs) { 1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins switch (opcode) { 1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kEq: 1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] == xs[1]; 1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kNe: 1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] != xs[1]; 1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kGt: 1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] > xs[1]; 1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kLt: 1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] < xs[1]; 1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kGe: 1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] >= xs[1]; 1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kLe: 1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] <= xs[1]; 1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins default: 1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LOG(FATAL) << "No comparatory opcode: " << opcode; 1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Test) { 1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // test expression: 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // slice(select({{T, F, T}, {F, T, F}}, 1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // concat(transpose({{1.0}, {2.0}, {3.0}} + 1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // {{-1.0}, {-1.0}, {-1.0}}), 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // {{1.62, 2.72, 3.14}}) + 1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), 1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} 1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 1869641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 18846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{1.0}, {2.0}, {3.0}}))); 1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( 19046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}}))); 1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( 1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1)); 1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose( 1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0})); 1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( 19646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{1.62, 2.72, 3.14}}))); 1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate( 1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0)); 1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const6 = builder.AddInstruction(HloInstruction::CreateConstant( 20046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); 2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary( 2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6)); 2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto add8 = builder.AddInstruction(HloInstruction::CreateBinary( 2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7)); 2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const9 = builder.AddInstruction(HloInstruction::CreateConstant( 20646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); 20746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto const10 = builder.AddInstruction(HloInstruction::CreateConstant( 20846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<bool>({{true, false, true}, {false, true, false}}))); 2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto select11 = builder.AddInstruction( 2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), 2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloOpcode::kSelect, const10, add8, const9)); 2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice( 2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2})); 2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // CreateFusionInstruction needs the `instructions_to_fuse` argument in 2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // reverse topological order, so the first element in `instructions_to_fuse` 2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // must be the root. 2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction( 2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {slice12, select11, const10, const9, add8, negate7, const6, concat5, 2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const4, reshape3, add2, const1, const0}, 2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 22346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}), 2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {}), 2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ErrorSpec(1e-4)); 2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Test whether we emit appropriate code for parameters of fusion instructions. 2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Parameter) { 2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Build a computation and fuse part of it so the fusion instruction has an 2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // operand parameter. 2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 2339641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 23546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{1.0, 2.0, 3.0}}))); 2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( 2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0)); 2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const2 = builder.AddInstruction(HloInstruction::CreateConstant( 23946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{-2.0, -2.0, -2.0}}))); 2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1} 2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( 2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2)); 2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological 2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // order. 2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, 2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 24946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}), 2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {}), 2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ErrorSpec(1e-4)); 2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { 2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 2569641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( 25846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR1<float>({1.0, 2.0, 3.0}))); 2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( 26046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); 2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto broadcast = builder.AddInstruction( 2621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1})); 2631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // add2 = broadcast(const_vector) + const_array 2641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}} 2651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}} 2661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto add2 = builder.AddInstruction( 2671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}), 2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloOpcode::kAdd, broadcast, const_array)); 2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast}, 2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectNear( 27446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), 2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); 2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, ReshapeToScalar) { 2791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 2809641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 2811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto single_element_array = builder.AddInstruction( 28246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR2<int32>({{5}}))); 2831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( 2841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), single_element_array)); 2851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 2861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, 2871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 28846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5), 2891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 2901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { 2931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 2949641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 2951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 29646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}))); 2971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( 2981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {1, 2, 3}), const0)); 2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 3021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 30346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}), 3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { 3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3099641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 31146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}))); 3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction( 3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0)); 3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 31846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}), 3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_1by1by1_) { 3231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3249641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction( 32646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR3<int32>({{{7}}}))); 3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction( 3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); 3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 33246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7), 3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape__1by1by1) { 3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3389641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction( 34046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(7))); 3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( 3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {1, 1, 1}), const0)); 3431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 34646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}), 3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape__) { 3511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3529641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction( 35446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(7))); 3551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction( 3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); 3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 36046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7), 3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_3by3_3by3) { 3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3669641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 36846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); 3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction( 3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0)); 3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 37546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), 3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Transpose_2by3) { 3801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3819641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 38346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}}))); 3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( 3851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0})); 3861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 39046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}), 3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Transpose_3by3) { 3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3969641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 39846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); 3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( 4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0})); 4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 40546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), 4061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reverse) { 4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 4119641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction( 41346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3}))); 4141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( 4151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {3}), const0, {0})); 4161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 4171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, 4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 4191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 42046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}), 4211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsstd::unique_ptr<HloComputation> MakeReduceTestComputation() { 4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder("add"); 4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( 4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs")); 4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto rhs = builder.AddInstruction(HloInstruction::CreateParameter( 4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs")); 4301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.AddInstruction(HloInstruction::CreateBinary( 4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs)); 4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return builder.Build(); 4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { 4369641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 43946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto const0 = builder.AddInstruction( 44046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8}))); 4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const1 = builder.AddInstruction( 44246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))); 4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( 4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, 4451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); 4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, 4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 45046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15), 4511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { 4559641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 4561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 45846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto const0 = builder.AddInstruction( 45946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8}))); 4601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const1 = builder.AddInstruction( 46146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))); 4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( 4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, 4641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); 4651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary( 4661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {1}), HloOpcode::kNegate, reduce2)); 4671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 4681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, 4691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 4701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 47146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-15}), 4721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 4731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { 4761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 4779641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 4781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 47946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); 4801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const1 = builder.AddInstruction( 48146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))); 4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Window window; 4831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ASSERT_TRUE( 4841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n" 4851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "size:2\n" 4861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "stride:1\n" 4871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "padding_low:0\n" 4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "padding_high:0\n" 4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "window_dilation:1\n" 4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "base_dilation:1\n" 4911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "}\n" 4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "dimensions:{\n" 4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "size:2\n" 4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "stride:1\n" 4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "padding_low:0\n" 4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "padding_high:0\n" 4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "window_dilation:1\n" 4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "base_dilation:1\n" 4991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "}\n", 5001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins &window)); 5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto nested_builder = HloComputation::Builder("mul"); 5021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins { 5031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto x = nested_builder.AddInstruction( 5041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x")); 5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto y = nested_builder.AddInstruction( 5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y")); 5071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins nested_builder.AddInstruction(HloInstruction::CreateBinary( 5081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y)); 5091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto nested_computation = 5111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEmbeddedComputation(nested_builder.Build()); 5121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reduce_window2 = 5131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.AddInstruction(HloInstruction::CreateReduceWindow( 5141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window, 5151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins nested_computation)); 5161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 5171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2}, 5181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 5191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 52146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}), 5221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 5231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); } 5261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Subtract2D) { 5281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kSubtract); 5291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Multiply2D) { 5321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kMultiply); 5331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Divide2D) { 5361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kDivide); 5371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Power2D) { 5401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kPower); 5411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Minimum2D) { 5441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kMinimum); 5451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Maximum2D) { 5481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kMaximum); 5491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<uint8, 2>(HloOpcode::kEq); } 5521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Inequal2D) { 5541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<uint8, 2>(HloOpcode::kNe); 5551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Greater2D) { 5581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<uint8, 2>(HloOpcode::kGt); 5591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Lesser2D) { 5621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<uint8, 2>(HloOpcode::kLt); 5631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, GreaterOrEqual2D) { 5661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<uint8, 2>(HloOpcode::kGe); 5671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, LesserOrEqual2D) { 5701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<uint8, 2>(HloOpcode::kLe); 5711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Clamp2D) { 5741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 3>(HloOpcode::kClamp); 5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5773b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowervoid BM_ParallelFusion(int num_iters) { 5783b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower // Simple element-wise computation to benchmark parallel task partitioning. 5793b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower tensorflow::testing::StopTiming(); 5803b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 5813b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); 5823b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); 5833b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower StreamExecutorMemoryAllocator allocator(platform, executors); 5843b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 5853b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower const int64 intra_op_parallelism_threads = 16; 5863b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower xla::LocalClientOptions client_options; 5873b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower client_options.set_platform(platform); 5883b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads); 5893b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto client = 5903b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); 5913b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 5923b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower const int64 dim_size = 1024; 5933b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower // Create a simple fusable elementwise computation. 5943b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower ComputationBuilder builder(client, "ParallelFusion"); 5953b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size}); 5963b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto input0 = builder.Broadcast(builder.ConstantR0<float>(1.5f), 5973b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower AsInt64Slice(input_shape.dimensions())); 5983b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto input1 = builder.Broadcast(builder.ConstantR0<float>(2.0f), 5993b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower AsInt64Slice(input_shape.dimensions())); 6003b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto input2 = builder.Broadcast(builder.ConstantR0<float>(3.0f), 6013b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower AsInt64Slice(input_shape.dimensions())); 6023b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto x = builder.Mul(input0, input1); 6033b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto y = builder.Add(x, input2); 6043b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto computation = builder.Build().ConsumeValueOrDie(); 6053b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 6063b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower std::unique_ptr<LocalExecutable> executable = 6073b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower client->Compile(computation, {}, ExecutableBuildOptions()) 6083b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower .ConsumeValueOrDie(); 6093b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 6103b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower // Run some warm-up executions. 6113b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower ExecutableRunOptions options; 6123b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower options.set_allocator(&allocator); 6133b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower const int kWarmups = 2; 6143b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower for (int i = 0; i < kWarmups; ++i) { 6153b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto result = executable->Run({}, options); 6163b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower ASSERT_TRUE(result.ok()); 6173b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower } 6183b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 6193b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower // Run benchmark. 6203b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) * dim_size * 6213b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower dim_size * sizeof(float)); 6223b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower tensorflow::testing::UseRealTime(); 6233b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower tensorflow::testing::StartTiming(); 6243b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower for (int i = 0; i < num_iters; ++i) { 6253b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto result = executable->Run({}, options); 6263b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower ASSERT_TRUE(result.ok()); 6273b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower } 6283b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower} 6293b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 6303b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowerBENCHMARK(BM_ParallelFusion); 6313b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 6321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace 6331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsint main(int argc, char** argv) { 6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<tensorflow::Flag> flag_list; 6379641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); 6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); 6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); 6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (!parse_result) { 6411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LOG(ERROR) << "\n" << usage; 6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return 2; 6431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins testing::InitGoogleTest(&argc, argv); 6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (argc > 1) { 6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; 6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return 2; 6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6493b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower tensorflow::testing::RunBenchmarks(); 6501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return RUN_ALL_TESTS(); 6511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 652