11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory> 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <numeric> 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector> 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h" 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array4d.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/local_client.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/statusor.h" 265c8acccfc9e90d694a8394f5522097bfe87379b2A. Unique TensorFlower#include "tensorflow/compiler/xla/test.h" 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h" 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/test_macros.h" 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 345f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerclass BroadcastSimpleTest : public ClientLibraryTestBase { 355f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower public: 365f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower ComputationDataHandle BuildBinOp(HloOpcode op, 375f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower const ComputationDataHandle& lhs, 385f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower const ComputationDataHandle& rhs, 395f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower ComputationBuilder* builder) { 405f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower switch (op) { 415f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower case HloOpcode::kMinimum: { 425f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return builder->Min(lhs, rhs); 435f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 445f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower case HloOpcode::kMaximum: { 455f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return builder->Max(lhs, rhs); 465f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 475f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower case HloOpcode::kMultiply: { 485f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return builder->Mul(lhs, rhs); 495f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 505f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower default: { 515f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower // Default to Add 525f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return builder->Add(lhs, rhs); 535f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 545f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 555f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 565f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 575f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::unique_ptr<GlobalData> MakeR3Data( 585f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower tensorflow::gtl::ArraySlice<int64> bounds, 595f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r3_shape, 605f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower Array3D<float>* r3_array, float start, float end, int seed) { 615f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); 625f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower r3_array->FillRandom(start, end, seed); 6346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout( 6446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LayoutUtil::MakeLayout(minor_to_major)); 655f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::unique_ptr<GlobalData> r3_global_data = 665f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower client_->TransferToServer(*r3_data).ConsumeValueOrDie(); 675f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return r3_global_data; 685f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 695f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 705f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::unique_ptr<GlobalData> MakeR2Data( 715f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower tensorflow::gtl::ArraySlice<int64> bounds, 725f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r2_shape, 735f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower Array2D<float>* r2_array, float start, float end, int seed) { 745f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); 755f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower r2_array->FillRandom(start, end, seed); 7646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout( 7746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower LayoutUtil::MakeLayout(minor_to_major)); 785f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::unique_ptr<GlobalData> r2_global_data = 795f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower client_->TransferToServer(*r2_data).ConsumeValueOrDie(); 805f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return r2_global_data; 815f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 825f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 835f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) { 845f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower switch (op) { 855f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower case HloOpcode::kMinimum: { 865f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return std::min(lhs, rhs); 875f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 885f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower case HloOpcode::kMaximum: { 895f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return std::max(lhs, rhs); 905f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 915f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower case HloOpcode::kMultiply: { 925f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return lhs * rhs; 935f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 945f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower case HloOpcode::kAdd: { 955f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower return lhs + rhs; 965f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 975f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower default: { 985f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower // Default to Add 996dd43ec8cb299459b835e50faa4f3ffad044098cA. Unique TensorFlower LOG(FATAL); 1005f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 1015f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 1025f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower } 1035f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower}; 1045f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 1059992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlowerusing ::testing::HasSubstr; 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { 1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Broadcast(b.ConstantR0<float>(1.5), {}); 1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR0<float>(&b, 1.5, {}, ErrorSpec(0.0001)); 1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { 1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Broadcast(b.ConstantR0<float>(2.25), {2, 3}); 1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> expected(2, 3, 2.25); 1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1200b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { 1210b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 1220b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower ComputationDataHandle src; 1230b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower std::unique_ptr<GlobalData> param_data = 1240b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower CreateR0Parameter<float>(2.25f, /*parameter_number=*/0, /*name=*/"src", 1250b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower /*builder=*/&b, /*data_handle=*/&src); 1260b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower 1270b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower b.Broadcast(src, {2, 3}); 1280b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower Array2D<float> expected(2, 3, 2.25); 1290b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower ComputeAndCompareR2<float>(&b, expected, {param_data.get()}, 1300b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower ErrorSpec(0.0001)); 1310b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower} 1320b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { 1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Broadcast(b.ConstantR0<float>(2.25), {2, 0}); 1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> expected(2, 0); 1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { 1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Broadcast(b.ConstantR0<float>(2.25), {0, 2}); 1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> expected(0, 2); 1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {2}); 1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> expected(2, 3); 1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected(0, 0) = 1; 1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected(0, 1) = 2; 1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected(0, 2) = 3; 1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected(1, 0) = 1; 1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected(1, 1) = 2; 1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected(1, 2) = 3; 1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins// Tests implicit broadcasting of PREDs. 162e56628b085ffa7922e5238537f6ebd6deee0f0ccA. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { 163074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins ComputationBuilder b(client_, TestName()); 164074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins 165074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins Array2D<bool> x_vals(2, 1); 166074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins x_vals(0, 0) = true; 167074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins x_vals(1, 0) = false; 168074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins Array3D<bool> y_vals(2, 2, 1); 169074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins y_vals(0, 0, 0) = false; 170074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins y_vals(0, 1, 0) = false; 171074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins y_vals(1, 0, 0) = true; 172074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins y_vals(1, 1, 0) = true; 173074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins 174074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins ComputationDataHandle x, y; 175074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x); 176074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y); 177e56628b085ffa7922e5238537f6ebd6deee0f0ccA. Unique TensorFlower b.And(x, y, /*broadcast_dimensions=*/{1, 2}); 178074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins 179074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins Array3D<bool> expected(2, 2, 1); 180074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins expected(0, 0, 0) = false; 181074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins expected(0, 1, 0) = false; 182074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins expected(1, 0, 0) = true; 183074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins expected(1, 1, 0) = false; 184074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins 185074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins ComputeAndCompareR3<bool>(&b, expected, {x_data.get(), y_data.get()}); 186074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins} 187074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins 1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { 1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Broadcast(b.ConstantR1<float>({}), {2}); 1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> expected(2, 0); 1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { 1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {0}); 1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> expected(0, 3); 2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { 2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Verify that binary op and degenerate dimension broadcast work together in 2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // the same operation. 2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // 2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // The lhs shape [1, 2] is first broadcast up to [2, 1, 2] using in-dimension 2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape 2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one 2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // dimensions. 2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Add(b.ConstantR2<float>({{1.0, 5.0}}), 21546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower b.ConstantLiteral(*Literal::CreateR3<float>( 2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), 2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1, 2}); 2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto expected = 22046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, 22146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); 2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 224688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 225688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 226688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerstruct R3ImplicitBroadcastSpec { 227688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower std::array<int64, 3> output_bounds; 228688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower std::array<int64, 3> minor2major_layout; 229688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower std::array<int64, 3> input_bounds; 230688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower HloOpcode op; 231688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} kR3ImplicitBroadcastTestCases[] = { 232688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, 233688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum}, 234688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum}, 235688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply}, 236688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, 237688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd}, 238688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd}, 239688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd}, 240688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum}, 241688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd}, 242688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}; 243688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 244688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerclass BroadcastR3ImplicitTest 245688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower : public BroadcastSimpleTest, 246688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower public ::testing::WithParamInterface<R3ImplicitBroadcastSpec> {}; 247688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 248688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_P(BroadcastR3ImplicitTest, Doit) { 249688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower const R3ImplicitBroadcastSpec& spec = GetParam(); 250688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 2515f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 2525f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower Shape r3_shape, r3_implicit_shape; 253688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower Array3D<float> r3_array(spec.output_bounds[0], spec.output_bounds[1], 254688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower spec.output_bounds[2]); 255688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower Array3D<float> r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1], 256688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower spec.input_bounds[2]); 2575f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 2585f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::unique_ptr<GlobalData> r3_global_data = 2595f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape, 2605f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower &r3_array, 1.0, 2.5, 56789); 261688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower std::unique_ptr<GlobalData> r3_implicit_global_data = 2625f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape, 2635f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower &r3_implicit_array, 1.0, 0.2, 56789); 264688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 265688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); 266688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r3_parameter = builder.Parameter(1, r3_shape, "input"); 2675f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower ComputationDataHandle op = 2685f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); 269688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 270688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1], 271688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower spec.output_bounds[2]); 272688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto Each = ([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) { 273688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0], 274688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower indices[1] % spec.input_bounds[1], 275688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower indices[2] % spec.input_bounds[2]); 276688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower float r3 = r3_array(indices[0], indices[1], indices[2]); 2775f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower *value = ApplyOpToFloats(spec.op, r3_implicit, r3); 278688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower }); 279688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 280688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower int n1 = expected_array.n1(); 281688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower int n2 = expected_array.n2(); 282688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower int n3 = expected_array.n3(); 283688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower for (int64 i = 0; i < n1; i++) { 284688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower for (int64 j = 0; j < n2; j++) { 285688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower for (int64 k = 0; k < n3; k++) { 286688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower Each({i, j, k}, &expected_array(i, j, k)); 287688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower } 288688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower } 289688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower } 29046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = Literal::CreateR3FromArray3D(expected_array); 291688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral( 292688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower &builder, *expected, 293688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower {r3_implicit_global_data.get(), r3_global_data.get()}, 294688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ErrorSpec(1e-7, 1e-7)); 295688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 296688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 297688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerINSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances, 298688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower BroadcastR3ImplicitTest, 299688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ::testing::ValuesIn(kR3ImplicitBroadcastTestCases)); 300688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 301688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower// r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1: 302688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { 303688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 304688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationDataHandle r1h; 305688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationDataHandle r3h; 306688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 307688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower Array3D<float> r1d = {{{1}}, {{2}}}; 308688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower Array3D<float> r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}; 309688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h); 310688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h); 311688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 312688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower b.Add(r3h, r1h); 313688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 314688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto expected = 31546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); 316688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 317688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, 318688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ErrorSpec(0.0001)); 319688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 320688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 321688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { 322688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 32346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}})); 324688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r3 = b.ConstantLiteral( 32546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 326688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower b.Add(r3, r1); 327688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 328688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto expected = 32946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); 330688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 331688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 332688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 333688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 334688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { 335688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 33646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}})); 337688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r3 = b.ConstantLiteral( 33846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 339688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower b.Add(r3, r1); 340688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 341688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto expected = 34246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); 343688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 344688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 345688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 346688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 347688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { 348688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 34946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}, {3, 4}}})); 350688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r3 = b.ConstantLiteral( 35146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 352688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower b.Add(r3, r1); 353688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 354688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto expected = 35546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); 356688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 357688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 358688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 359688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 360688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { 361688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 36246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}})); 363688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r3 = b.ConstantLiteral( 36446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 365688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower b.Add(r3, r1); 366688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 367688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto expected = 36846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); 369688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 370688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 371688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 372688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 373688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { 374688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 37546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r1 = 37646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}})); 377688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r3 = b.ConstantLiteral( 37846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 379688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower b.Add(r3, r1); 380688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 381688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto expected = 38246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); 383688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 384688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 385688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 386688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 387688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { 388688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 38946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}}})); 390688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto r3 = b.ConstantLiteral( 39146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 392688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower b.Add(r3, r1); 393688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 394688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower auto expected = 39546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); 396688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 397688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 398688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 399688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 4005f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerstruct R2ImplicitBroadcastSpec { 4015f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::array<int64, 2> output_bounds; 4025f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::array<int64, 2> minor2major_layout; 4035f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::array<int64, 2> input_bounds1; 4045f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::array<int64, 2> input_bounds2; 4055f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode op1; 4065f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode op2; 4075f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower} kR2ImplicitBroadcastTestCases[] = { 4085f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd}, 4095f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd}, 4105f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{2, 3}}, 4115f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4125f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{2, 1}}, 4135f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 1}}, 4145f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4155f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kMinimum}, 4165f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{2, 3}}, 4175f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4185f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 3}}, 4195f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 1}}, 4205f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4215f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kMinimum}, 4225f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{2, 3}}, 4235f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4245f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 1}}, 4255f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 1}}, 4265f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4275f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kMinimum}, 4285f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd}, 4295f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{150, 150}}, 4305f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4315f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{150, 1}}, 4325f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{150, 1}}, 4335f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4345f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd}, 4355f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{150, 150}}, 4365f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4375f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{150, 1}}, 4385f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 150}}, 4395f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4405f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd}, 4415f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{150, 150}}, 4425f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4435f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{150, 1}}, 4445f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 1}}, 4455f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4465f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd}, 4475f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{50, 150}}, 4485f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4495f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{50, 1}}, 4505f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{50, 1}}, 4515f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4525f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd}, 4535f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{50, 150}}, 4545f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4555f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{50, 1}}, 4565f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 150}}, 4575f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4585f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd}, 4595f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{50, 150}}, 4605f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4615f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{50, 1}}, 4625f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 1}}, 4635f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4645f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd}, 4655f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{150, 50}}, 4665f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4675f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{150, 1}}, 4685f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{150, 1}}, 4695f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4705f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd}, 4715f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{150, 50}}, 4725f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4735f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{150, 1}}, 4745f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 50}}, 4755f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4765f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd}, 4775f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{{150, 50}}, 4785f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 0}}, 4795f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{150, 1}}, 4805f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {{1, 1}}, 4815f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd, 4825f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower HloOpcode::kAdd}}; 4835f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 4845f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerclass BroadcastR2ImplicitTest 4855f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower : public BroadcastSimpleTest, 4865f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower public ::testing::WithParamInterface<R2ImplicitBroadcastSpec> {}; 4875f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 4885f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower// Test r2 op1 r2_implicit_1 op2 r2_implicit_2 4895f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower// where R2 is a rank-2 operand, and r2_implicit_2 are two 4905f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower// rank-2 operands with degenerate dimensions: 4915f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerXLA_TEST_P(BroadcastR2ImplicitTest, Doit) { 4925f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower const R2ImplicitBroadcastSpec& spec = GetParam(); 4935f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 4945f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 4955f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 4965f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower // Operands with degenerate dimensions require implicit broadcasting: 4975f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2; 4985f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower Array2D<float> r2_array(spec.output_bounds[0], spec.output_bounds[1]); 4995f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower Array2D<float> r2_implicit_array1(spec.input_bounds1[0], 5005f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower spec.input_bounds1[1]); 5015f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower Array2D<float> r2_implicit_array2(spec.input_bounds2[0], 5025f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower spec.input_bounds2[1]); 5035f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 5045f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::unique_ptr<GlobalData> r2_global_data = 5055f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower MakeR2Data(spec.output_bounds, spec.minor2major_layout, &r2_shape, 5065f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower &r2_array, 1.0, 2.5, 56789); 5075f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::unique_ptr<GlobalData> r2_implicit_global_data1 = 5085f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower MakeR2Data(spec.input_bounds1, spec.minor2major_layout, 5095f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower &r2_implicit_shape1, &r2_implicit_array1, 1.0, 0.2, 56789); 5105f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower std::unique_ptr<GlobalData> r2_implicit_global_data2 = 5115f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower MakeR2Data(spec.input_bounds2, spec.minor2major_layout, 5125f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789); 5135f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 5145f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower auto r2_implicit_parameter1 = 5155f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower builder.Parameter(0, r2_implicit_shape1, "input0"); 5165f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower auto r2_parameter = builder.Parameter(1, r2_shape, "input1"); 5175f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower auto r2_implicit_parameter2 = 5185f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower builder.Parameter(2, r2_implicit_shape2, "input2"); 5195f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 5205f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower ComputationDataHandle op1 = 5215f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder); 5225f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower ComputationDataHandle op2 = 5235f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); 5245f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 5255f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]); 5265f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 5275f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower expected_array.Each([&](int64 i, int64 j, float* v) { 5285f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower float v1 = r2_implicit_array1(i % spec.input_bounds1[0], 5295f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower j % spec.input_bounds1[1]); 5305f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower float v2 = r2_array(i, j); 5315f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower float v3 = r2_implicit_array2(i % spec.input_bounds2[0], 5325f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower j % spec.input_bounds2[1]); 5335f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower float tmp = ApplyOpToFloats(spec.op1, v1, v2); 5345f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower *v = ApplyOpToFloats(spec.op2, tmp, v3); 5355f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower }); 5365f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 53746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = Literal::CreateR2FromArray2D(expected_array); 5385f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower ComputeAndCompareLiteral( 5395f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower &builder, *expected, 5405f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower {r2_implicit_global_data1.get(), r2_global_data.get(), 5415f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower r2_implicit_global_data2.get()}, 5425f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower ErrorSpec(1e-6, 1e-6)); 5435f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower} 5445f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 5455f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerINSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, 5465f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower BroadcastR2ImplicitTest, 5475f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower ::testing::ValuesIn(kR2ImplicitBroadcastTestCases)); 5485f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower 549688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { 550688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 55146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}})); 55246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}})); 553688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower b.Add(r2, r1); 554688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 55546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = Literal::CreateR2<float>({{2, 4}, {4, 6}}); 556688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 557688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 558688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} 559688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 560688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { 561688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputationBuilder b(client_, TestName()); 56246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1}, {2}})); 56346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}})); 564688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower b.Add(r2, r1); 565688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 56646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = Literal::CreateR2<float>({{2, 3}, {5, 6}}); 567688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower 568688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 5691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 571902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { 572902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune ComputationBuilder b(client_, TestName()); 573902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r1 = b.ConstantR1<float>({10, 20}); 574902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r3 = b.ConstantLiteral( 57546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 576902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune b.Add(r3, r1, {0}); 577902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 57846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = 57946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); 580902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 581902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 582902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune} 583902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 584902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { 585902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune ComputationBuilder b(client_, TestName()); 586902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r1 = b.ConstantR1<float>({10, 20}); 587902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r3 = b.ConstantLiteral( 58846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 589902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune b.Add(r1, r3, {1}); 590902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 59146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = 59246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); 593902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 594902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 595902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune} 596902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 597902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { 598902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune ComputationBuilder b(client_, TestName()); 599902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r1 = b.ConstantR1<float>({10, 20}); 600902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r3 = b.ConstantLiteral( 60146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 602902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune b.Add(r1, r3, {2}); 603902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 60446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = 60546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower Literal::CreateR3<float>({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); 606902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 607902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 608902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune} 609902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 610902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { 611902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune ComputationBuilder b(client_, TestName()); 612902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r1_0 = b.ConstantR1<float>({1000, 2000}); 613902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r1_1 = b.ConstantR1<float>({100, 200}); 614902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r1_2 = b.ConstantR1<float>({10, 20}); 615902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune auto r3 = b.ConstantLiteral( 61646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 617902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune for (int i = 0; i < 3; ++i) { 618902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune r3 = b.Add(r1_0, r3, {0}); 619902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune r3 = b.Add(r3, r1_1, {1}); 620902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune r3 = b.Add(r1_2, r3, {2}); 621902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune } 622902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune r3 = b.Mul(r3, b.ConstantR0<float>(-2)); 623902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 62446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = Literal::CreateR3<float>( 625902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, 626902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); 627902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 628902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 629902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune} 630902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune 631ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { 632ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune ComputationBuilder b(client_, TestName()); 633ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune auto r1_0 = b.ConstantR1<float>({1000, 2000}); 634ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune auto r1_1 = b.ConstantR1<float>({100, 200}); 635ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune auto r1_2 = b.ConstantR1<float>({10, 20}); 636ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune auto r0 = b.ConstantR0<float>(3); 637ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune auto r3 = b.Broadcast(r0, {2, 2, 2}); 638ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune for (int i = 0; i < 3; ++i) { 639ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune r3 = b.Add(r1_0, r3, {0}); 640ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune r3 = b.Add(r3, r1_1, {1}); 641ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune r3 = b.Add(r1_2, r3, {2}); 642ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune } 643ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune r3 = b.Mul(r3, b.ConstantR0<float>(-1)); 644ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune 64546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower auto expected = Literal::CreateR3<float>( 646ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, 647ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); 648ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune 649ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 650ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune} 651ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune 6521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { 6531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2]) 6541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // results in a shape incompatible with the lhs [2, 3, 1]. 6551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 6561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Add(b.ConstantR2<float>({{1.0, 5.0}, {1.0, 5.0}}), 65846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower b.ConstantLiteral(*Literal::CreateR3<float>( 6591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), 6601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1, 2}); 6611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result_status = Execute(&b, {}); 6631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins EXPECT_FALSE(result_status.ok()); 6645c8acccfc9e90d694a8394f5522097bfe87379b2A. Unique TensorFlower EXPECT_THAT(result_status.status().error_message(), 6659992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower HasSubstr("broadcast dimension 0 mismatch")); 6661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { 6691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Test invalid broadcasting with [1, 2] and [2, 3] inputs. 6701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 6711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Add(b.ConstantR2<float>({{1.0, 2.0}}), 6731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); 6741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result_status = Execute(&b, {}); 6761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins EXPECT_FALSE(result_status.ok()); 6779992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower EXPECT_THAT(result_status.status().error_message(), 6789992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower HasSubstr("binary op BINOP_ADD with incompatible shapes")); 6791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { 6821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Test invalid broadcasting with [1, 2] and [2, 3] inputs. 6831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder b(client_, TestName()); 6841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.Add(b.ConstantR2<float>({{1.0, 2.0}}), 6861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); 6871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto result_status = Execute(&b, {}); 6891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins EXPECT_FALSE(result_status.ok()); 6909992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower EXPECT_THAT(result_status.status().error_message(), 6919992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower HasSubstr("binary op BINOP_ADD with incompatible shapes")); 6921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 6931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 6941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace 6951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 696