11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <math.h> 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <algorithm> 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory> 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <new> 20a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower#include <random> 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <utility> 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 23a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower#define EIGEN_USE_THREADS 24a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 25a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h" 273b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/client_library.h" 283b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/computation.h" 293b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/client/computation_builder.h" 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h" 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/primitive_util.h" 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/ptr_util.h" 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_computation.h" 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_instruction.h" 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_module.h" 361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_opcode.h" 373b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/service/platform_util.h" 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h" 393b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h" 421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/test_macros.h" 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h" 44a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower#include "tensorflow/core/common_runtime/eigen_thread_pool.h" 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/gtl/array_slice.h" 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/protobuf.h" 483b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower#include "tensorflow/core/platform/test_benchmark.h" 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h" 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsusing tensorflow::gtl::ArraySlice; 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 533b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowernamespace se = ::perftools::gputools; 543b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsconst int test_width = 2, test_height = 3; 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsconst float test_float_vals[3][test_width][test_height] = { 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}}, 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}}, 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}}}; 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Test whether fusion operations are emitted with no errors and compute 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// accurate outputs. 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass FusionTest : public HloTestBase { 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected: 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins template <typename T, int Arity> 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins void TestElementwise2D(HloOpcode opcode) { 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> operand_data[Arity]; 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < Arity; ++i) { 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins new (&operand_data[i]) Array2D<float>(test_width, test_height); 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<T> answer_data(test_width, test_height); 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < test_width; ++i) { 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int j = 0; j < test_height; ++j) { 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins float xs[Arity]; 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int k = 0; k < Arity; ++k) { 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins xs[k] = test_float_vals[k][i][j]; 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins operand_data[k](i, j) = xs[k]; 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs); 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 889641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto prim_type = primitive_util::NativeToPrimitiveType<T>(); 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction* hlos[4]; 931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < Arity; ++i) { 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant( 9546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2FromArray2D(operand_data[i]))); 961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto answer_shape = 981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(prim_type, {test_width, test_height}); 991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::unique_ptr<HloInstruction> root_hlo; 1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins switch (Arity) { 1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case 1: 1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]); 1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins break; 1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case 2: 1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1], 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlos[2]); 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins break; 1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case 3: 1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1], 1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlos[2], hlos[3]); 1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins break; 1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins default: 1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LOG(FATAL) << "Bad arity: " << Arity; 1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlos[0] = builder.AddInstruction(std::move(root_hlo)); 1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction( 1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ArraySlice<HloInstruction*>(hlos, 0, Arity + 1), 1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 12146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = Literal::CreateR2FromArray2D(answer_data); 1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); 1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (primitive_util::IsFloatingPointType(prim_type)) { 1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); 1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } else { 1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual(*expected, *actual); 1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private: 1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins template <typename T> 1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice<float> xs); 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinstemplate <> 1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsfloat FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode, 1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ArraySlice<float> xs) { 1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins switch (opcode) { 1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kAdd: 1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] + xs[1]; 1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kSubtract: 1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] - xs[1]; 1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kMultiply: 1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] * xs[1]; 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kDivide: 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] / xs[1]; 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kPower: 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return powf(xs[0], xs[1]); 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kMinimum: 1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return std::min(xs[0], xs[1]); 1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kMaximum: 1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return std::max(xs[0], xs[1]); 1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kClamp: 1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return std::min(xs[2], std::max(xs[1], xs[0])); 1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins default: 1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LOG(FATAL) << "No elementwise opcode: " << opcode; 1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinstemplate <> 16111698cc8e157eefe71a60931f1e721ad327e08afMark Heffernanbool FusionTest::ComputeElementwiseAnswer<bool>(HloOpcode opcode, 16211698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan ArraySlice<float> xs) { 1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins switch (opcode) { 1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kEq: 1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] == xs[1]; 1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kNe: 1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] != xs[1]; 1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kGt: 1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] > xs[1]; 1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kLt: 1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] < xs[1]; 1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kGe: 1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] >= xs[1]; 1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins case HloOpcode::kLe: 1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return xs[0] <= xs[1]; 1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins default: 1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LOG(FATAL) << "No comparatory opcode: " << opcode; 1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Test) { 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // test expression: 1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // slice(select({{T, F, T}, {F, T, F}}, 1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // concat(transpose({{1.0}, {2.0}, {3.0}} + 1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // {{-1.0}, {-1.0}, {-1.0}}), 1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // {{1.62, 2.72, 3.14}}) + 1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), 1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} 1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 1909641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 19246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{1.0}, {2.0}, {3.0}}))); 1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( 19446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}}))); 1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( 1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1)); 1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose( 1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0})); 1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( 20046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{1.62, 2.72, 3.14}}))); 2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate( 2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0)); 2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const6 = builder.AddInstruction(HloInstruction::CreateConstant( 20446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); 2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary( 2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6)); 2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto add8 = builder.AddInstruction(HloInstruction::CreateBinary( 2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7)); 2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const9 = builder.AddInstruction(HloInstruction::CreateConstant( 21046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); 21146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto const10 = builder.AddInstruction(HloInstruction::CreateConstant( 21246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<bool>({{true, false, true}, {false, true, false}}))); 2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto select11 = builder.AddInstruction( 2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), 2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloOpcode::kSelect, const10, add8, const9)); 2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice( 21750b999a8336d19400ab75aea66fe46eca2f5fe0bA. Unique TensorFlower ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1})); 2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // CreateFusionInstruction needs the `instructions_to_fuse` argument in 2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // reverse topological order, so the first element in `instructions_to_fuse` 2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // must be the root. 2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction( 2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {slice12, select11, const10, const9, add8, negate7, const6, concat5, 2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const4, reshape3, add2, const1, const0}, 2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 22746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}), 2281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {}), 2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ErrorSpec(1e-4)); 2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Test whether we emit appropriate code for parameters of fusion instructions. 2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Parameter) { 2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Build a computation and fuse part of it so the fusion instruction has an 2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // operand parameter. 2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 2379641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 23946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{1.0, 2.0, 3.0}}))); 2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( 2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0)); 2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const2 = builder.AddInstruction(HloInstruction::CreateConstant( 24346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{-2.0, -2.0, -2.0}}))); 2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1} 2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( 2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2)); 2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological 2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // order. 2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, 2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 25346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}), 2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {}), 2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ErrorSpec(1e-4)); 2561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 258a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerXLA_TEST_F(FusionTest, RandomizedParallelPartition) { 259a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Tests parallel partitioning of a fusion instruction. 260a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Create shape with random outer dimension size to generate random parallel 261a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // partition counts for each test run. 262a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int seed = tensorflow::testing::RandomSeed(); 263a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower LOG(INFO) << "RandomizedParallelPartition seed: " << seed; 264a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower std::mt19937 generator(seed); 265a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower std::uniform_int_distribution<int> distribution(128, 1024); 266a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 rand_dim0_size = distribution(generator); 267a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 dim1_size = 1024; 268a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower Shape shape = 269a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); 270a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Build simple fusion computation: y = x^2 (elementwise). 271a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto builder = HloComputation::Builder(TestName()); 272a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto hlo_module = CreateNewModule(); 273a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 274a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto two = builder.AddInstruction( 275a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 276a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto x = 277a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {})); 278a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto y = builder.AddInstruction( 279a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x)); 280a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 281a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower hlo_module->AddEntryComputation(builder.Build()) 282a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two}, 283a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower HloInstruction::FusionKind::kLoop); 284a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Compute result. 285a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 286a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Every element of result should be y = x^2 = 4.0. 287a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower for (int i = 0; i < rand_dim0_size; ++i) { 288a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower for (int j = 0; j < dim1_size; ++j) { 289a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower EXPECT_EQ(4.0, result->Get<float>({i, j})); 290a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 291a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 292a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower} 293a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 2941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { 2951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 2969641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 2971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( 29846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR1<float>({1.0, 2.0, 3.0}))); 2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( 30046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); 3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto broadcast = builder.AddInstruction( 3021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1})); 3031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // add2 = broadcast(const_vector) + const_array 3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}} 3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}} 3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto add2 = builder.AddInstruction( 3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}), 3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloOpcode::kAdd, broadcast, const_array)); 3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast}, 3111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 3121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectNear( 31446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), 3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); 3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, ReshapeToScalar) { 3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3209641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto single_element_array = builder.AddInstruction( 32246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR2<int32>({{5}}))); 3231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( 3241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), single_element_array)); 3251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, 3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 32846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5), 3291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { 3331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3349641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 33646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}))); 3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( 3381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {1, 2, 3}), const0)); 3391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 3421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 34346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}), 3441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { 3481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3499641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 35146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}))); 3521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction( 3531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0)); 3541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 3571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 35846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}), 3591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_1by1by1_) { 3631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3649641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction( 36646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR3<int32>({{{7}}}))); 3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction( 3681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); 3691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 37246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7), 3731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape__1by1by1) { 3771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3789641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction( 38046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(7))); 3811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( 3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {1, 1, 1}), const0)); 3831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 38646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}), 3871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 3881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 3891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape__) { 3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 3929641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 3931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction( 39446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(7))); 3951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction( 3961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); 3971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 3981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 40046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7), 4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reshape_3by3_3by3) { 4051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 4069641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 40846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); 4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction( 4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0)); 4111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 4121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 4131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 4141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 41546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), 4161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 4171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Transpose_2by3) { 4201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 4219641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 4221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 42346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}}))); 4241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( 4251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0})); 4261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 4271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 4281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 4291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 43046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}), 4311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 4321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Transpose_3by3) { 4351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 4369641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 4371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 43846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); 4391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( 4401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0})); 4411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 4441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 44546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), 4461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 4471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Reverse) { 4501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 4519641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 4521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction( 45346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3}))); 4541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( 4551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {3}), const0, {0})); 4561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 4571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, 4581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 4591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 46046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}), 4611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 4621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 4631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 46446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, ReverseNegate) { 46546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto builder = HloComputation::Builder(TestName()); 46646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto hlo_module = CreateNewModule(); 46746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto const0 = builder.AddInstruction( 46846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3}))); 46946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( 47046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {3}), const0, {0})); 47146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( 47246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {3}), HloOpcode::kNegate, reverse1)); 47346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower hlo_module->AddEntryComputation(builder.Build()) 47446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1}, 47546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::FusionKind::kLoop); 47646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 47746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-3, -2, -1}), 47846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower *ExecuteAndTransfer(std::move(hlo_module), {})); 47946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower} 48046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 48146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, BroadcastNegate) { 48246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto builder = HloComputation::Builder(TestName()); 48346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto hlo_module = CreateNewModule(); 48446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto const0 = builder.AddInstruction( 48546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))); 48646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( 48746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {2}), const0, {})); 48846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( 48946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, broadcast1)); 49046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower hlo_module->AddEntryComputation(builder.Build()) 49146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1}, 49246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::FusionKind::kLoop); 49346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 49446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -1}), 49546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower *ExecuteAndTransfer(std::move(hlo_module), {})); 49646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower} 49746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 49846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, SliceNegate) { 49946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto builder = HloComputation::Builder(TestName()); 50046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto hlo_module = CreateNewModule(); 50146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto const0 = builder.AddInstruction( 50246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4}))); 50346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( 50446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2})); 50546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( 50646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1)); 50746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower hlo_module->AddEntryComputation(builder.Build()) 50846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1}, 50946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::FusionKind::kLoop); 51046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 51146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -3}), 51246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower *ExecuteAndTransfer(std::move(hlo_module), {})); 51346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower} 51446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 51546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, DynamicSliceNegate) { 51646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto builder = HloComputation::Builder(TestName()); 51746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto hlo_module = CreateNewModule(); 51846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto const0 = builder.AddInstruction( 51946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4}))); 52046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto const1 = builder.AddInstruction( 52146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1}))); 52246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto dynamic_slice2 = 52346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower builder.AddInstruction(HloInstruction::CreateDynamicSlice( 52446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {2}), const0, const1, {2})); 52546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary( 52646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2)); 52746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower hlo_module->AddEntryComputation(builder.Build()) 52846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ->CreateFusionInstruction( 52946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower /*instructions_to_fuse=*/{negate3, dynamic_slice2}, 53046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::FusionKind::kLoop); 53146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 53246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-2, -3}), 53346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower *ExecuteAndTransfer(std::move(hlo_module), {})); 53446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower} 53546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 53646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, ReshapeNegate) { 53746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto builder = HloComputation::Builder(TestName()); 53846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto hlo_module = CreateNewModule(); 53946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto const0 = builder.AddInstruction( 54046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4}))); 54146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto reshape1 = builder.AddInstruction( 54246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0)); 54346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( 54446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, reshape1)); 54546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower hlo_module->AddEntryComputation(builder.Build()) 54646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, 54746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::FusionKind::kLoop); 54846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 54946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}), 55046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower *ExecuteAndTransfer(std::move(hlo_module), {})); 55146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower} 55246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 55346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower// TODO(b/64070202): Investigate failure. 55446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlowerXLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { 55546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto builder = HloComputation::Builder(TestName()); 55646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto hlo_module = CreateNewModule(); 55746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 55846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower Literal::CreateR2<int32>({{1, 2}, {3, 4}}))); 55946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( 56046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0})); 56146ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( 56246ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, transpose1)); 56346ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower hlo_module->AddEntryComputation(builder.Build()) 56446ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, 56546ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower HloInstruction::FusionKind::kLoop); 56646ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 56746ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}), 56846ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower *ExecuteAndTransfer(std::move(hlo_module), {})); 56946ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower} 57046ad960feeb59a985a49069a378e20e658dca436A. Unique TensorFlower 5711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsstd::unique_ptr<HloComputation> MakeReduceTestComputation() { 5721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder("add"); 5731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( 5741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs")); 5751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto rhs = builder.AddInstruction(HloInstruction::CreateParameter( 5761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs")); 5771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.AddInstruction(HloInstruction::CreateBinary( 5781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs)); 5791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return builder.Build(); 5801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { 5839641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 5841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 58646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto const0 = builder.AddInstruction( 58746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8}))); 5881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const1 = builder.AddInstruction( 58946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))); 5901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( 5911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, 5921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); 5931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 5941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, 5951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 5961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 59746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15), 5981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 5991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { 6029641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 6031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 60546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto const0 = builder.AddInstruction( 60646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8}))); 6071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const1 = builder.AddInstruction( 60846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))); 6091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( 6101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, 6111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); 6121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary( 61311698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, reduce2)); 6141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 6151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, 6161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 6171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 61811698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(-15), 6191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 6201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { 6231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto builder = HloComputation::Builder(TestName()); 6249641b8edab3113f5bb83b5491de747dc9a43fe01Eli Bendersky auto hlo_module = CreateNewModule(); 6251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 62646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); 6271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto const1 = builder.AddInstruction( 62846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))); 6291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Window window; 6301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ASSERT_TRUE( 6311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n" 6321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "size:2\n" 6331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "stride:1\n" 6341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "padding_low:0\n" 6351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "padding_high:0\n" 6361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "window_dilation:1\n" 6371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "base_dilation:1\n" 6381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "}\n" 6391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "dimensions:{\n" 6401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "size:2\n" 6411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "stride:1\n" 6421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "padding_low:0\n" 6431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "padding_high:0\n" 6441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "window_dilation:1\n" 6451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "base_dilation:1\n" 6461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "}\n", 6471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins &window)); 6481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto nested_builder = HloComputation::Builder("mul"); 6491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins { 6501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto x = nested_builder.AddInstruction( 6511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x")); 6521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto y = nested_builder.AddInstruction( 6531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y")); 6541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins nested_builder.AddInstruction(HloInstruction::CreateBinary( 6551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y)); 6561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 6571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto nested_computation = 6581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEmbeddedComputation(nested_builder.Build()); 6591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto reduce_window2 = 6601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.AddInstruction(HloInstruction::CreateReduceWindow( 6611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window, 6621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins nested_computation)); 6631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins hlo_module->AddEntryComputation(builder.Build()) 6641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2}, 6651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins HloInstruction::FusionKind::kLoop); 6661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins LiteralTestUtil::ExpectEqual( 66846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}), 6691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins *ExecuteAndTransfer(std::move(hlo_module), {})); 6701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 672a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan// When a constant (or other op) which has multiple users is imported 673a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan// into a fusion, it should remain shared, rather than being duplicated 674a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan// within the fusion. 675a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay VasudevanXLA_TEST_F(FusionTest, SharedConstant) { 676a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan auto hlo_module = CreateNewModule(); 677a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan 678a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan auto builder = HloComputation::Builder(TestName()); 679a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan auto const0 = builder.AddInstruction( 680a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan HloInstruction::CreateConstant(Literal::CreateR1<int32>({0}))); 681a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan auto const1 = builder.AddInstruction( 682a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan HloInstruction::CreateConstant(Literal::CreateR1<int32>({2}))); 683a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 684a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); 685a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( 686a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1)); 687a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( 688a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2)); 689a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan auto add4 = builder.AddInstruction(HloInstruction::CreateBinary( 690a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3)); 691a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan hlo_module->AddEntryComputation(builder.Build()) 692a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan ->CreateFusionInstruction( 693a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan {add4, add3, add2, add1, const1}, 694a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan HloInstruction::FusionKind::kLoop); 695a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan 696a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan HloComputation* entry_comp = hlo_module->entry_computation(); 697a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan 698a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan // entry computation contains the constant(0) and the fusion 6999b1b5d85b9ce3c812dc772da1f3f5d09581e5b49Justin Lebar EXPECT_EQ(entry_comp->instruction_count(), 2); 700a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan 701a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan // fused instruction contains the constant(2), the parameter, and 4 adds 7029b1b5d85b9ce3c812dc772da1f3f5d09581e5b49Justin Lebar EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); 703a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan 704a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}), 705a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan *ExecuteAndTransfer(std::move(hlo_module), {})); 706a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan} 707a1fba7f5ac3de39b106af36c3737ea854f09e9acVijay Vasudevan 7081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); } 7091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Subtract2D) { 7111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kSubtract); 7121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Multiply2D) { 7151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kMultiply); 7161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Divide2D) { 7191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kDivide); 7201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Power2D) { 7231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kPower); 7241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Minimum2D) { 7271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kMinimum); 7281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Maximum2D) { 7311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 2>(HloOpcode::kMaximum); 7321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 73411698cc8e157eefe71a60931f1e721ad327e08afMark HeffernanXLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<bool, 2>(HloOpcode::kEq); } 7351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Inequal2D) { 73711698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan TestElementwise2D<bool, 2>(HloOpcode::kNe); 7381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Greater2D) { 74111698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan TestElementwise2D<bool, 2>(HloOpcode::kGt); 7421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 74411698cc8e157eefe71a60931f1e721ad327e08afMark HeffernanXLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D<bool, 2>(HloOpcode::kLt); } 7451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, GreaterOrEqual2D) { 74711698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan TestElementwise2D<bool, 2>(HloOpcode::kGe); 7481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, LesserOrEqual2D) { 75111698cc8e157eefe71a60931f1e721ad327e08afMark Heffernan TestElementwise2D<bool, 2>(HloOpcode::kLe); 7521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(FusionTest, Clamp2D) { 7551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TestElementwise2D<float, 3>(HloOpcode::kClamp); 7561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 7571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 7583b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowervoid BM_ParallelFusion(int num_iters) { 7593b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower // Simple element-wise computation to benchmark parallel task partitioning. 7603b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower tensorflow::testing::StopTiming(); 7613b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 7623b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); 7633b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); 7643b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower StreamExecutorMemoryAllocator allocator(platform, executors); 7653b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 766a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 intra_op_parallelism_threads = 24; 7673b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower xla::LocalClientOptions client_options; 7683b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower client_options.set_platform(platform); 7693b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads); 7703b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto client = 7713b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); 7723b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 773a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower int device_ordinal = client->default_device_ordinal(); 774a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 775a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Computation shape parameters. 776a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 param0_dim0 = 1024; 777a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 param0_dim1 = 1024; 778a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 param1_dim0 = 1024; 779a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 param1_dim1 = 1024; 780a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 param2_dim0 = 1024; 781a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 param2_dim1 = 1024; 782a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 783a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Create computation. 7843b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower ComputationBuilder builder(client, "ParallelFusion"); 785a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); 786a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto param0 = builder.Parameter(0, shape0, "param0"); 787a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); 788a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto param1 = builder.Parameter(1, shape1, "param1"); 789a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1}); 790a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto param2 = builder.Parameter(2, shape2, "param2"); 791a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 792a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto x = builder.Mul(param0, param1); 793a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto y = builder.Add(x, param2); 7943b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower auto computation = builder.Build().ConsumeValueOrDie(); 7953b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 796a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Transfer literals to device. 797a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto param0_literal = 798a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); 79922d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan std::unique_ptr<ShapedBuffer> buffer0 = 80022d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan client->LiteralToShapedBuffer(*param0_literal, device_ordinal) 801a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower .ConsumeValueOrDie(); 80222d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan 803a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto param1_literal = 804a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); 80522d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan std::unique_ptr<ShapedBuffer> buffer1 = 80622d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan client->LiteralToShapedBuffer(*param1_literal, device_ordinal) 807a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower .ConsumeValueOrDie(); 80822d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan 809a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto param2_literal = 810a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); 81122d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan std::unique_ptr<ShapedBuffer> buffer2 = 81222d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan client->LiteralToShapedBuffer(*param2_literal, device_ordinal) 81322d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan .ConsumeValueOrDie(); 814a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 815a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Build executable. 8163b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower std::unique_ptr<LocalExecutable> executable = 817a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower client 818a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower ->Compile(computation, 819fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower {&buffer0->on_host_shape(), &buffer1->on_host_shape(), 820fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower &buffer2->on_host_shape()}, 821a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower ExecutableBuildOptions()) 8223b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower .ConsumeValueOrDie(); 8233b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 82422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan se::Stream stream(executors[device_ordinal]); 825a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower stream.Init(); 826a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 827a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Initialize thread pool. 828a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", 829a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower intra_op_parallelism_threads); 830a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower tensorflow::EigenThreadPoolWrapper tp(&pool); 831a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); 832a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 833a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Initialize ExecutableRunOptions. 8341f1b2bb6c3833a472036da22b7c910f5f2bdf694Anna R ExecutableRunOptions options; 835a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower options.set_allocator(&allocator).set_stream(&stream); 836a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower options.set_intra_op_thread_pool(&device); 837a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 838a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Run some warm-up executions. 8393b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower const int kWarmups = 2; 8403b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower for (int i = 0; i < kWarmups; ++i) { 841a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto result = 842a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); 8433b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower ASSERT_TRUE(result.ok()); 8443b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower } 8453b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 8463b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower // Run benchmark. 847a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 total_bytes = param0_dim0 * param0_dim0 + 848a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower param1_dim0 * param1_dim0 + 849a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower param2_dim0 * param2_dim0; 850a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) * 851a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower total_bytes * sizeof(float)); 8523b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower tensorflow::testing::UseRealTime(); 8533b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower tensorflow::testing::StartTiming(); 8543b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower for (int i = 0; i < num_iters; ++i) { 855a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto result = 856a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); 8573b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower ASSERT_TRUE(result.ok()); 8583b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower } 8593b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower} 8603b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 8613b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlowerBENCHMARK(BM_ParallelFusion); 8623b41352a3177c2fe8a1329e8981b285bb6aacf8bA. Unique TensorFlower 8631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace 8641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 865