1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <memory> 17#include <numeric> 18#include <vector> 19 20#include "tensorflow/compiler/xla/array2d.h" 21#include "tensorflow/compiler/xla/array4d.h" 22#include "tensorflow/compiler/xla/client/computation_builder.h" 23#include "tensorflow/compiler/xla/client/local_client.h" 24#include "tensorflow/compiler/xla/literal_util.h" 25#include "tensorflow/compiler/xla/statusor.h" 26#include "tensorflow/compiler/xla/test.h" 27#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 28#include "tensorflow/compiler/xla/tests/literal_test_util.h" 29#include "tensorflow/compiler/xla/tests/test_macros.h" 30 31namespace xla { 32namespace { 33 34class BroadcastSimpleTest : public ClientLibraryTestBase { 35 public: 36 ComputationDataHandle BuildBinOp(HloOpcode op, 37 const ComputationDataHandle& lhs, 38 const ComputationDataHandle& rhs, 39 ComputationBuilder* builder) { 40 switch (op) { 41 case HloOpcode::kMinimum: { 42 return builder->Min(lhs, rhs); 43 } 44 case HloOpcode::kMaximum: { 45 return builder->Max(lhs, rhs); 46 } 47 case HloOpcode::kMultiply: { 48 return builder->Mul(lhs, rhs); 49 } 50 default: { 51 // Default to Add 52 return builder->Add(lhs, rhs); 53 } 54 } 55 } 56 57 std::unique_ptr<GlobalData> MakeR3Data( 58 tensorflow::gtl::ArraySlice<int64> bounds, 59 tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r3_shape, 60 Array3D<float>* r3_array, float start, float end, int seed) { 61 *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); 62 r3_array->FillRandom(start, end, seed); 63 auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout( 64 LayoutUtil::MakeLayout(minor_to_major)); 65 std::unique_ptr<GlobalData> r3_global_data = 66 client_->TransferToServer(*r3_data).ConsumeValueOrDie(); 67 return r3_global_data; 68 } 69 70 std::unique_ptr<GlobalData> MakeR2Data( 71 tensorflow::gtl::ArraySlice<int64> bounds, 72 tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r2_shape, 73 Array2D<float>* r2_array, float start, float end, int seed) { 74 *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); 75 r2_array->FillRandom(start, end, seed); 76 auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout( 77 LayoutUtil::MakeLayout(minor_to_major)); 78 std::unique_ptr<GlobalData> r2_global_data = 79 client_->TransferToServer(*r2_data).ConsumeValueOrDie(); 80 return r2_global_data; 81 } 82 83 float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) { 84 switch (op) { 85 case HloOpcode::kMinimum: { 86 return std::min(lhs, rhs); 87 } 88 case HloOpcode::kMaximum: { 89 return std::max(lhs, rhs); 90 } 91 case HloOpcode::kMultiply: { 92 return lhs * rhs; 93 } 94 case HloOpcode::kAdd: { 95 return lhs + rhs; 96 } 97 default: { 98 // Default to Add 99 LOG(FATAL); 100 } 101 } 102 } 103}; 104 105using ::testing::HasSubstr; 106 107XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { 108 ComputationBuilder b(client_, TestName()); 109 b.Broadcast(b.ConstantR0<float>(1.5), {}); 110 ComputeAndCompareR0<float>(&b, 1.5, {}, ErrorSpec(0.0001)); 111} 112 113XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { 114 ComputationBuilder b(client_, TestName()); 115 b.Broadcast(b.ConstantR0<float>(2.25), {2, 3}); 116 Array2D<float> expected(2, 3, 2.25); 117 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 118} 119 120XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { 121 ComputationBuilder b(client_, TestName()); 122 ComputationDataHandle src; 123 std::unique_ptr<GlobalData> param_data = 124 CreateR0Parameter<float>(2.25f, /*parameter_number=*/0, /*name=*/"src", 125 /*builder=*/&b, /*data_handle=*/&src); 126 127 b.Broadcast(src, {2, 3}); 128 Array2D<float> expected(2, 3, 2.25); 129 ComputeAndCompareR2<float>(&b, expected, {param_data.get()}, 130 ErrorSpec(0.0001)); 131} 132 133XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { 134 ComputationBuilder b(client_, TestName()); 135 b.Broadcast(b.ConstantR0<float>(2.25), {2, 0}); 136 Array2D<float> expected(2, 0); 137 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 138} 139 140XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { 141 ComputationBuilder b(client_, TestName()); 142 b.Broadcast(b.ConstantR0<float>(2.25), {0, 2}); 143 Array2D<float> expected(0, 2); 144 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 145} 146 147XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { 148 ComputationBuilder b(client_, TestName()); 149 b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {2}); 150 151 Array2D<float> expected(2, 3); 152 expected(0, 0) = 1; 153 expected(0, 1) = 2; 154 expected(0, 2) = 3; 155 expected(1, 0) = 1; 156 expected(1, 1) = 2; 157 expected(1, 2) = 3; 158 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 159} 160 161// Tests implicit broadcasting of PREDs. 162XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { 163 ComputationBuilder b(client_, TestName()); 164 165 Array2D<bool> x_vals(2, 1); 166 x_vals(0, 0) = true; 167 x_vals(1, 0) = false; 168 Array3D<bool> y_vals(2, 2, 1); 169 y_vals(0, 0, 0) = false; 170 y_vals(0, 1, 0) = false; 171 y_vals(1, 0, 0) = true; 172 y_vals(1, 1, 0) = true; 173 174 ComputationDataHandle x, y; 175 auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x); 176 auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y); 177 b.And(x, y, /*broadcast_dimensions=*/{1, 2}); 178 179 Array3D<bool> expected(2, 2, 1); 180 expected(0, 0, 0) = false; 181 expected(0, 1, 0) = false; 182 expected(1, 0, 0) = true; 183 expected(1, 1, 0) = false; 184 185 ComputeAndCompareR3<bool>(&b, expected, {x_data.get(), y_data.get()}); 186} 187 188XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { 189 ComputationBuilder b(client_, TestName()); 190 b.Broadcast(b.ConstantR1<float>({}), {2}); 191 192 Array2D<float> expected(2, 0); 193 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 194} 195 196XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { 197 ComputationBuilder b(client_, TestName()); 198 b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {0}); 199 200 Array2D<float> expected(0, 3); 201 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); 202} 203 204XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { 205 // Verify that binary op and degenerate dimension broadcast work together in 206 // the same operation. 207 // 208 // The lhs shape [1, 2] is first broadcast up to [2, 1, 2] using in-dimension 209 // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape 210 // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one 211 // dimensions. 212 ComputationBuilder b(client_, TestName()); 213 214 b.Add(b.ConstantR2<float>({{1.0, 5.0}}), 215 b.ConstantLiteral(*Literal::CreateR3<float>( 216 {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), 217 /*broadcast_dimensions=*/{1, 2}); 218 219 auto expected = 220 Literal::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, 221 {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); 222 223 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 224} 225 226struct R3ImplicitBroadcastSpec { 227 std::array<int64, 3> output_bounds; 228 std::array<int64, 3> minor2major_layout; 229 std::array<int64, 3> input_bounds; 230 HloOpcode op; 231} kR3ImplicitBroadcastTestCases[] = { 232 {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, 233 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum}, 234 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum}, 235 {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply}, 236 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, 237 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd}, 238 {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd}, 239 {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd}, 240 {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum}, 241 {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd}, 242}; 243 244class BroadcastR3ImplicitTest 245 : public BroadcastSimpleTest, 246 public ::testing::WithParamInterface<R3ImplicitBroadcastSpec> {}; 247 248XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { 249 const R3ImplicitBroadcastSpec& spec = GetParam(); 250 ComputationBuilder builder(client_, TestName()); 251 252 Shape r3_shape, r3_implicit_shape; 253 Array3D<float> r3_array(spec.output_bounds[0], spec.output_bounds[1], 254 spec.output_bounds[2]); 255 Array3D<float> r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1], 256 spec.input_bounds[2]); 257 258 std::unique_ptr<GlobalData> r3_global_data = 259 MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape, 260 &r3_array, 1.0, 2.5, 56789); 261 std::unique_ptr<GlobalData> r3_implicit_global_data = 262 MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape, 263 &r3_implicit_array, 1.0, 0.2, 56789); 264 265 auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); 266 auto r3_parameter = builder.Parameter(1, r3_shape, "input"); 267 ComputationDataHandle op = 268 BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); 269 270 Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1], 271 spec.output_bounds[2]); 272 auto Each = ([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) { 273 float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0], 274 indices[1] % spec.input_bounds[1], 275 indices[2] % spec.input_bounds[2]); 276 float r3 = r3_array(indices[0], indices[1], indices[2]); 277 *value = ApplyOpToFloats(spec.op, r3_implicit, r3); 278 }); 279 280 int n1 = expected_array.n1(); 281 int n2 = expected_array.n2(); 282 int n3 = expected_array.n3(); 283 for (int64 i = 0; i < n1; i++) { 284 for (int64 j = 0; j < n2; j++) { 285 for (int64 k = 0; k < n3; k++) { 286 Each({i, j, k}, &expected_array(i, j, k)); 287 } 288 } 289 } 290 auto expected = Literal::CreateR3FromArray3D(expected_array); 291 ComputeAndCompareLiteral( 292 &builder, *expected, 293 {r3_implicit_global_data.get(), r3_global_data.get()}, 294 ErrorSpec(1e-7, 1e-7)); 295} 296 297INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances, 298 BroadcastR3ImplicitTest, 299 ::testing::ValuesIn(kR3ImplicitBroadcastTestCases)); 300 301// r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1: 302XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { 303 ComputationBuilder b(client_, TestName()); 304 ComputationDataHandle r1h; 305 ComputationDataHandle r3h; 306 307 Array3D<float> r1d = {{{1}}, {{2}}}; 308 Array3D<float> r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}; 309 auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h); 310 auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h); 311 312 b.Add(r3h, r1h); 313 314 auto expected = 315 Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); 316 317 ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, 318 ErrorSpec(0.0001)); 319} 320 321XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { 322 ComputationBuilder b(client_, TestName()); 323 auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}})); 324 auto r3 = b.ConstantLiteral( 325 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 326 b.Add(r3, r1); 327 328 auto expected = 329 Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); 330 331 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 332} 333 334XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { 335 ComputationBuilder b(client_, TestName()); 336 auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}})); 337 auto r3 = b.ConstantLiteral( 338 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 339 b.Add(r3, r1); 340 341 auto expected = 342 Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); 343 344 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 345} 346 347XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { 348 ComputationBuilder b(client_, TestName()); 349 auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}, {3, 4}}})); 350 auto r3 = b.ConstantLiteral( 351 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 352 b.Add(r3, r1); 353 354 auto expected = 355 Literal::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); 356 357 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 358} 359 360XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { 361 ComputationBuilder b(client_, TestName()); 362 auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}})); 363 auto r3 = b.ConstantLiteral( 364 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 365 b.Add(r3, r1); 366 367 auto expected = 368 Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); 369 370 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 371} 372 373XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { 374 ComputationBuilder b(client_, TestName()); 375 auto r1 = 376 b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}})); 377 auto r3 = b.ConstantLiteral( 378 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 379 b.Add(r3, r1); 380 381 auto expected = 382 Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); 383 384 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 385} 386 387XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { 388 ComputationBuilder b(client_, TestName()); 389 auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}}})); 390 auto r3 = b.ConstantLiteral( 391 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 392 b.Add(r3, r1); 393 394 auto expected = 395 Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); 396 397 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 398} 399 400struct R2ImplicitBroadcastSpec { 401 std::array<int64, 2> output_bounds; 402 std::array<int64, 2> minor2major_layout; 403 std::array<int64, 2> input_bounds1; 404 std::array<int64, 2> input_bounds2; 405 HloOpcode op1; 406 HloOpcode op2; 407} kR2ImplicitBroadcastTestCases[] = { 408 {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd}, 409 {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd}, 410 {{{2, 3}}, 411 {{1, 0}}, 412 {{2, 1}}, 413 {{1, 1}}, 414 HloOpcode::kAdd, 415 HloOpcode::kMinimum}, 416 {{{2, 3}}, 417 {{1, 0}}, 418 {{1, 3}}, 419 {{1, 1}}, 420 HloOpcode::kAdd, 421 HloOpcode::kMinimum}, 422 {{{2, 3}}, 423 {{1, 0}}, 424 {{1, 1}}, 425 {{1, 1}}, 426 HloOpcode::kAdd, 427 HloOpcode::kMinimum}, 428 {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd}, 429 {{{150, 150}}, 430 {{1, 0}}, 431 {{150, 1}}, 432 {{150, 1}}, 433 HloOpcode::kAdd, 434 HloOpcode::kAdd}, 435 {{{150, 150}}, 436 {{1, 0}}, 437 {{150, 1}}, 438 {{1, 150}}, 439 HloOpcode::kAdd, 440 HloOpcode::kAdd}, 441 {{{150, 150}}, 442 {{1, 0}}, 443 {{150, 1}}, 444 {{1, 1}}, 445 HloOpcode::kAdd, 446 HloOpcode::kAdd}, 447 {{{50, 150}}, 448 {{1, 0}}, 449 {{50, 1}}, 450 {{50, 1}}, 451 HloOpcode::kAdd, 452 HloOpcode::kAdd}, 453 {{{50, 150}}, 454 {{1, 0}}, 455 {{50, 1}}, 456 {{1, 150}}, 457 HloOpcode::kAdd, 458 HloOpcode::kAdd}, 459 {{{50, 150}}, 460 {{1, 0}}, 461 {{50, 1}}, 462 {{1, 1}}, 463 HloOpcode::kAdd, 464 HloOpcode::kAdd}, 465 {{{150, 50}}, 466 {{1, 0}}, 467 {{150, 1}}, 468 {{150, 1}}, 469 HloOpcode::kAdd, 470 HloOpcode::kAdd}, 471 {{{150, 50}}, 472 {{1, 0}}, 473 {{150, 1}}, 474 {{1, 50}}, 475 HloOpcode::kAdd, 476 HloOpcode::kAdd}, 477 {{{150, 50}}, 478 {{1, 0}}, 479 {{150, 1}}, 480 {{1, 1}}, 481 HloOpcode::kAdd, 482 HloOpcode::kAdd}}; 483 484class BroadcastR2ImplicitTest 485 : public BroadcastSimpleTest, 486 public ::testing::WithParamInterface<R2ImplicitBroadcastSpec> {}; 487 488// Test r2 op1 r2_implicit_1 op2 r2_implicit_2 489// where R2 is a rank-2 operand, and r2_implicit_2 are two 490// rank-2 operands with degenerate dimensions: 491XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { 492 const R2ImplicitBroadcastSpec& spec = GetParam(); 493 494 ComputationBuilder builder(client_, TestName()); 495 496 // Operands with degenerate dimensions require implicit broadcasting: 497 Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2; 498 Array2D<float> r2_array(spec.output_bounds[0], spec.output_bounds[1]); 499 Array2D<float> r2_implicit_array1(spec.input_bounds1[0], 500 spec.input_bounds1[1]); 501 Array2D<float> r2_implicit_array2(spec.input_bounds2[0], 502 spec.input_bounds2[1]); 503 504 std::unique_ptr<GlobalData> r2_global_data = 505 MakeR2Data(spec.output_bounds, spec.minor2major_layout, &r2_shape, 506 &r2_array, 1.0, 2.5, 56789); 507 std::unique_ptr<GlobalData> r2_implicit_global_data1 = 508 MakeR2Data(spec.input_bounds1, spec.minor2major_layout, 509 &r2_implicit_shape1, &r2_implicit_array1, 1.0, 0.2, 56789); 510 std::unique_ptr<GlobalData> r2_implicit_global_data2 = 511 MakeR2Data(spec.input_bounds2, spec.minor2major_layout, 512 &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789); 513 514 auto r2_implicit_parameter1 = 515 builder.Parameter(0, r2_implicit_shape1, "input0"); 516 auto r2_parameter = builder.Parameter(1, r2_shape, "input1"); 517 auto r2_implicit_parameter2 = 518 builder.Parameter(2, r2_implicit_shape2, "input2"); 519 520 ComputationDataHandle op1 = 521 BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder); 522 ComputationDataHandle op2 = 523 BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); 524 525 Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]); 526 527 expected_array.Each([&](int64 i, int64 j, float* v) { 528 float v1 = r2_implicit_array1(i % spec.input_bounds1[0], 529 j % spec.input_bounds1[1]); 530 float v2 = r2_array(i, j); 531 float v3 = r2_implicit_array2(i % spec.input_bounds2[0], 532 j % spec.input_bounds2[1]); 533 float tmp = ApplyOpToFloats(spec.op1, v1, v2); 534 *v = ApplyOpToFloats(spec.op2, tmp, v3); 535 }); 536 537 auto expected = Literal::CreateR2FromArray2D(expected_array); 538 ComputeAndCompareLiteral( 539 &builder, *expected, 540 {r2_implicit_global_data1.get(), r2_global_data.get(), 541 r2_implicit_global_data2.get()}, 542 ErrorSpec(1e-6, 1e-6)); 543} 544 545INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, 546 BroadcastR2ImplicitTest, 547 ::testing::ValuesIn(kR2ImplicitBroadcastTestCases)); 548 549XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { 550 ComputationBuilder b(client_, TestName()); 551 auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}})); 552 auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}})); 553 b.Add(r2, r1); 554 555 auto expected = Literal::CreateR2<float>({{2, 4}, {4, 6}}); 556 557 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 558} 559 560XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { 561 ComputationBuilder b(client_, TestName()); 562 auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1}, {2}})); 563 auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}})); 564 b.Add(r2, r1); 565 566 auto expected = Literal::CreateR2<float>({{2, 3}, {5, 6}}); 567 568 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 569} 570 571XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { 572 ComputationBuilder b(client_, TestName()); 573 auto r1 = b.ConstantR1<float>({10, 20}); 574 auto r3 = b.ConstantLiteral( 575 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 576 b.Add(r3, r1, {0}); 577 578 auto expected = 579 Literal::CreateR3<float>({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); 580 581 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 582} 583 584XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { 585 ComputationBuilder b(client_, TestName()); 586 auto r1 = b.ConstantR1<float>({10, 20}); 587 auto r3 = b.ConstantLiteral( 588 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 589 b.Add(r1, r3, {1}); 590 591 auto expected = 592 Literal::CreateR3<float>({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); 593 594 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 595} 596 597XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { 598 ComputationBuilder b(client_, TestName()); 599 auto r1 = b.ConstantR1<float>({10, 20}); 600 auto r3 = b.ConstantLiteral( 601 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 602 b.Add(r1, r3, {2}); 603 604 auto expected = 605 Literal::CreateR3<float>({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); 606 607 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 608} 609 610XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { 611 ComputationBuilder b(client_, TestName()); 612 auto r1_0 = b.ConstantR1<float>({1000, 2000}); 613 auto r1_1 = b.ConstantR1<float>({100, 200}); 614 auto r1_2 = b.ConstantR1<float>({10, 20}); 615 auto r3 = b.ConstantLiteral( 616 *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 617 for (int i = 0; i < 3; ++i) { 618 r3 = b.Add(r1_0, r3, {0}); 619 r3 = b.Add(r3, r1_1, {1}); 620 r3 = b.Add(r1_2, r3, {2}); 621 } 622 r3 = b.Mul(r3, b.ConstantR0<float>(-2)); 623 624 auto expected = Literal::CreateR3<float>( 625 {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, 626 {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); 627 628 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 629} 630 631XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { 632 ComputationBuilder b(client_, TestName()); 633 auto r1_0 = b.ConstantR1<float>({1000, 2000}); 634 auto r1_1 = b.ConstantR1<float>({100, 200}); 635 auto r1_2 = b.ConstantR1<float>({10, 20}); 636 auto r0 = b.ConstantR0<float>(3); 637 auto r3 = b.Broadcast(r0, {2, 2, 2}); 638 for (int i = 0; i < 3; ++i) { 639 r3 = b.Add(r1_0, r3, {0}); 640 r3 = b.Add(r3, r1_1, {1}); 641 r3 = b.Add(r1_2, r3, {2}); 642 } 643 r3 = b.Mul(r3, b.ConstantR0<float>(-1)); 644 645 auto expected = Literal::CreateR3<float>( 646 {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, 647 {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); 648 649 ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); 650} 651 652XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { 653 // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2]) 654 // results in a shape incompatible with the lhs [2, 3, 1]. 655 ComputationBuilder b(client_, TestName()); 656 657 b.Add(b.ConstantR2<float>({{1.0, 5.0}, {1.0, 5.0}}), 658 b.ConstantLiteral(*Literal::CreateR3<float>( 659 {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), 660 /*broadcast_dimensions=*/{1, 2}); 661 662 auto result_status = Execute(&b, {}); 663 EXPECT_FALSE(result_status.ok()); 664 EXPECT_THAT(result_status.status().error_message(), 665 HasSubstr("broadcast dimension 0 mismatch")); 666} 667 668XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { 669 // Test invalid broadcasting with [1, 2] and [2, 3] inputs. 670 ComputationBuilder b(client_, TestName()); 671 672 b.Add(b.ConstantR2<float>({{1.0, 2.0}}), 673 b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); 674 675 auto result_status = Execute(&b, {}); 676 EXPECT_FALSE(result_status.ok()); 677 EXPECT_THAT(result_status.status().error_message(), 678 HasSubstr("binary op BINOP_ADD with incompatible shapes")); 679} 680 681XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { 682 // Test invalid broadcasting with [1, 2] and [2, 3] inputs. 683 ComputationBuilder b(client_, TestName()); 684 685 b.Add(b.ConstantR2<float>({{1.0, 2.0}}), 686 b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); 687 688 auto result_status = Execute(&b, {}); 689 EXPECT_FALSE(result_status.ok()); 690 EXPECT_THAT(result_status.status().error_message(), 691 HasSubstr("binary op BINOP_ADD with incompatible shapes")); 692} 693 694} // namespace 695} // namespace xla 696