1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <memory>
17#include <numeric>
18#include <vector>
19
20#include "tensorflow/compiler/xla/array2d.h"
21#include "tensorflow/compiler/xla/array4d.h"
22#include "tensorflow/compiler/xla/client/computation_builder.h"
23#include "tensorflow/compiler/xla/client/local_client.h"
24#include "tensorflow/compiler/xla/literal_util.h"
25#include "tensorflow/compiler/xla/statusor.h"
26#include "tensorflow/compiler/xla/test.h"
27#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28#include "tensorflow/compiler/xla/tests/literal_test_util.h"
29#include "tensorflow/compiler/xla/tests/test_macros.h"
30
31namespace xla {
32namespace {
33
34class BroadcastSimpleTest : public ClientLibraryTestBase {
35 public:
36  ComputationDataHandle BuildBinOp(HloOpcode op,
37                                   const ComputationDataHandle& lhs,
38                                   const ComputationDataHandle& rhs,
39                                   ComputationBuilder* builder) {
40    switch (op) {
41      case HloOpcode::kMinimum: {
42        return builder->Min(lhs, rhs);
43      }
44      case HloOpcode::kMaximum: {
45        return builder->Max(lhs, rhs);
46      }
47      case HloOpcode::kMultiply: {
48        return builder->Mul(lhs, rhs);
49      }
50      default: {
51        // Default to Add
52        return builder->Add(lhs, rhs);
53      }
54    }
55  }
56
57  std::unique_ptr<GlobalData> MakeR3Data(
58      tensorflow::gtl::ArraySlice<int64> bounds,
59      tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r3_shape,
60      Array3D<float>* r3_array, float start, float end, int seed) {
61    *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
62    r3_array->FillRandom(start, end, seed);
63    auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout(
64        LayoutUtil::MakeLayout(minor_to_major));
65    std::unique_ptr<GlobalData> r3_global_data =
66        client_->TransferToServer(*r3_data).ConsumeValueOrDie();
67    return r3_global_data;
68  }
69
70  std::unique_ptr<GlobalData> MakeR2Data(
71      tensorflow::gtl::ArraySlice<int64> bounds,
72      tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r2_shape,
73      Array2D<float>* r2_array, float start, float end, int seed) {
74    *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
75    r2_array->FillRandom(start, end, seed);
76    auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout(
77        LayoutUtil::MakeLayout(minor_to_major));
78    std::unique_ptr<GlobalData> r2_global_data =
79        client_->TransferToServer(*r2_data).ConsumeValueOrDie();
80    return r2_global_data;
81  }
82
83  float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) {
84    switch (op) {
85      case HloOpcode::kMinimum: {
86        return std::min(lhs, rhs);
87      }
88      case HloOpcode::kMaximum: {
89        return std::max(lhs, rhs);
90      }
91      case HloOpcode::kMultiply: {
92        return lhs * rhs;
93      }
94      case HloOpcode::kAdd: {
95        return lhs + rhs;
96      }
97      default: {
98        // Default to Add
99        LOG(FATAL);
100      }
101    }
102  }
103};
104
105using ::testing::HasSubstr;
106
107XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
108  ComputationBuilder b(client_, TestName());
109  b.Broadcast(b.ConstantR0<float>(1.5), {});
110  ComputeAndCompareR0<float>(&b, 1.5, {}, ErrorSpec(0.0001));
111}
112
113XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
114  ComputationBuilder b(client_, TestName());
115  b.Broadcast(b.ConstantR0<float>(2.25), {2, 3});
116  Array2D<float> expected(2, 3, 2.25);
117  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
118}
119
120XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
121  ComputationBuilder b(client_, TestName());
122  ComputationDataHandle src;
123  std::unique_ptr<GlobalData> param_data =
124      CreateR0Parameter<float>(2.25f, /*parameter_number=*/0, /*name=*/"src",
125                               /*builder=*/&b, /*data_handle=*/&src);
126
127  b.Broadcast(src, {2, 3});
128  Array2D<float> expected(2, 3, 2.25);
129  ComputeAndCompareR2<float>(&b, expected, {param_data.get()},
130                             ErrorSpec(0.0001));
131}
132
133XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
134  ComputationBuilder b(client_, TestName());
135  b.Broadcast(b.ConstantR0<float>(2.25), {2, 0});
136  Array2D<float> expected(2, 0);
137  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
138}
139
140XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
141  ComputationBuilder b(client_, TestName());
142  b.Broadcast(b.ConstantR0<float>(2.25), {0, 2});
143  Array2D<float> expected(0, 2);
144  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
145}
146
147XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
148  ComputationBuilder b(client_, TestName());
149  b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {2});
150
151  Array2D<float> expected(2, 3);
152  expected(0, 0) = 1;
153  expected(0, 1) = 2;
154  expected(0, 2) = 3;
155  expected(1, 0) = 1;
156  expected(1, 1) = 2;
157  expected(1, 2) = 3;
158  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
159}
160
161// Tests implicit broadcasting of PREDs.
162XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
163  ComputationBuilder b(client_, TestName());
164
165  Array2D<bool> x_vals(2, 1);
166  x_vals(0, 0) = true;
167  x_vals(1, 0) = false;
168  Array3D<bool> y_vals(2, 2, 1);
169  y_vals(0, 0, 0) = false;
170  y_vals(0, 1, 0) = false;
171  y_vals(1, 0, 0) = true;
172  y_vals(1, 1, 0) = true;
173
174  ComputationDataHandle x, y;
175  auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x);
176  auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y);
177  b.And(x, y, /*broadcast_dimensions=*/{1, 2});
178
179  Array3D<bool> expected(2, 2, 1);
180  expected(0, 0, 0) = false;
181  expected(0, 1, 0) = false;
182  expected(1, 0, 0) = true;
183  expected(1, 1, 0) = false;
184
185  ComputeAndCompareR3<bool>(&b, expected, {x_data.get(), y_data.get()});
186}
187
188XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
189  ComputationBuilder b(client_, TestName());
190  b.Broadcast(b.ConstantR1<float>({}), {2});
191
192  Array2D<float> expected(2, 0);
193  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
194}
195
196XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
197  ComputationBuilder b(client_, TestName());
198  b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {0});
199
200  Array2D<float> expected(0, 3);
201  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
202}
203
204XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
205  // Verify that binary op and degenerate dimension broadcast work together in
206  // the same operation.
207  //
208  // The lhs shape [1, 2] is first broadcast up to [2, 1, 2] using in-dimension
209  // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape
210  // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one
211  // dimensions.
212  ComputationBuilder b(client_, TestName());
213
214  b.Add(b.ConstantR2<float>({{1.0, 5.0}}),
215        b.ConstantLiteral(*Literal::CreateR3<float>(
216            {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
217        /*broadcast_dimensions=*/{1, 2});
218
219  auto expected =
220      Literal::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
221                                {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
222
223  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
224}
225
226struct R3ImplicitBroadcastSpec {
227  std::array<int64, 3> output_bounds;
228  std::array<int64, 3> minor2major_layout;
229  std::array<int64, 3> input_bounds;
230  HloOpcode op;
231} kR3ImplicitBroadcastTestCases[] = {
232    {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
233    {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum},
234    {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum},
235    {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply},
236    {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
237    {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd},
238    {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd},
239    {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd},
240    {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum},
241    {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd},
242};
243
// Value-parameterized fixture: inherits the BroadcastSimpleTest helpers and
// receives one R3ImplicitBroadcastSpec per test instance through GetParam().
class BroadcastR3ImplicitTest
    : public BroadcastSimpleTest,
      public ::testing::WithParamInterface<R3ImplicitBroadcastSpec> {};
247
248XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
249  const R3ImplicitBroadcastSpec& spec = GetParam();
250  ComputationBuilder builder(client_, TestName());
251
252  Shape r3_shape, r3_implicit_shape;
253  Array3D<float> r3_array(spec.output_bounds[0], spec.output_bounds[1],
254                          spec.output_bounds[2]);
255  Array3D<float> r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1],
256                                   spec.input_bounds[2]);
257
258  std::unique_ptr<GlobalData> r3_global_data =
259      MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape,
260                 &r3_array, 1.0, 2.5, 56789);
261  std::unique_ptr<GlobalData> r3_implicit_global_data =
262      MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape,
263                 &r3_implicit_array, 1.0, 0.2, 56789);
264
265  auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input");
266  auto r3_parameter = builder.Parameter(1, r3_shape, "input");
267  ComputationDataHandle op =
268      BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
269
270  Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
271                                spec.output_bounds[2]);
272  auto Each = ([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
273    float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0],
274                                          indices[1] % spec.input_bounds[1],
275                                          indices[2] % spec.input_bounds[2]);
276    float r3 = r3_array(indices[0], indices[1], indices[2]);
277    *value = ApplyOpToFloats(spec.op, r3_implicit, r3);
278  });
279
280  int n1 = expected_array.n1();
281  int n2 = expected_array.n2();
282  int n3 = expected_array.n3();
283  for (int64 i = 0; i < n1; i++) {
284    for (int64 j = 0; j < n2; j++) {
285      for (int64 k = 0; k < n3; k++) {
286        Each({i, j, k}, &expected_array(i, j, k));
287      }
288    }
289  }
290  auto expected = Literal::CreateR3FromArray3D(expected_array);
291  ComputeAndCompareLiteral(
292      &builder, *expected,
293      {r3_implicit_global_data.get(), r3_global_data.get()},
294      ErrorSpec(1e-7, 1e-7));
295}
296
// Instantiates BroadcastR3ImplicitTest once for every entry of
// kR3ImplicitBroadcastTestCases.
INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances,
                        BroadcastR3ImplicitTest,
                        ::testing::ValuesIn(kR3ImplicitBroadcastTestCases));
300
301// r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1:
302XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
303  ComputationBuilder b(client_, TestName());
304  ComputationDataHandle r1h;
305  ComputationDataHandle r3h;
306
307  Array3D<float> r1d = {{{1}}, {{2}}};
308  Array3D<float> r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
309  auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h);
310  auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h);
311
312  b.Add(r3h, r1h);
313
314  auto expected =
315      Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
316
317  ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
318                           ErrorSpec(0.0001));
319}
320
321XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
322  ComputationBuilder b(client_, TestName());
323  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}}));
324  auto r3 = b.ConstantLiteral(
325      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
326  b.Add(r3, r1);
327
328  auto expected =
329      Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
330
331  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
332}
333
334XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
335  ComputationBuilder b(client_, TestName());
336  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}}));
337  auto r3 = b.ConstantLiteral(
338      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
339  b.Add(r3, r1);
340
341  auto expected =
342      Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
343
344  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
345}
346
347XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
348  ComputationBuilder b(client_, TestName());
349  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}, {3, 4}}}));
350  auto r3 = b.ConstantLiteral(
351      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
352  b.Add(r3, r1);
353
354  auto expected =
355      Literal::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
356
357  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
358}
359
360XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
361  ComputationBuilder b(client_, TestName());
362  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
363  auto r3 = b.ConstantLiteral(
364      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
365  b.Add(r3, r1);
366
367  auto expected =
368      Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
369
370  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
371}
372
373XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
374  ComputationBuilder b(client_, TestName());
375  auto r1 =
376      b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
377  auto r3 = b.ConstantLiteral(
378      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
379  b.Add(r3, r1);
380
381  auto expected =
382      Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
383
384  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
385}
386
387XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
388  ComputationBuilder b(client_, TestName());
389  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}}}));
390  auto r3 = b.ConstantLiteral(
391      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
392  b.Add(r3, r1);
393
394  auto expected =
395      Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
396
397  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
398}
399
400struct R2ImplicitBroadcastSpec {
401  std::array<int64, 2> output_bounds;
402  std::array<int64, 2> minor2major_layout;
403  std::array<int64, 2> input_bounds1;
404  std::array<int64, 2> input_bounds2;
405  HloOpcode op1;
406  HloOpcode op2;
407} kR2ImplicitBroadcastTestCases[] = {
408    {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
409    {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd},
410    {{{2, 3}},
411     {{1, 0}},
412     {{2, 1}},
413     {{1, 1}},
414     HloOpcode::kAdd,
415     HloOpcode::kMinimum},
416    {{{2, 3}},
417     {{1, 0}},
418     {{1, 3}},
419     {{1, 1}},
420     HloOpcode::kAdd,
421     HloOpcode::kMinimum},
422    {{{2, 3}},
423     {{1, 0}},
424     {{1, 1}},
425     {{1, 1}},
426     HloOpcode::kAdd,
427     HloOpcode::kMinimum},
428    {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
429    {{{150, 150}},
430     {{1, 0}},
431     {{150, 1}},
432     {{150, 1}},
433     HloOpcode::kAdd,
434     HloOpcode::kAdd},
435    {{{150, 150}},
436     {{1, 0}},
437     {{150, 1}},
438     {{1, 150}},
439     HloOpcode::kAdd,
440     HloOpcode::kAdd},
441    {{{150, 150}},
442     {{1, 0}},
443     {{150, 1}},
444     {{1, 1}},
445     HloOpcode::kAdd,
446     HloOpcode::kAdd},
447    {{{50, 150}},
448     {{1, 0}},
449     {{50, 1}},
450     {{50, 1}},
451     HloOpcode::kAdd,
452     HloOpcode::kAdd},
453    {{{50, 150}},
454     {{1, 0}},
455     {{50, 1}},
456     {{1, 150}},
457     HloOpcode::kAdd,
458     HloOpcode::kAdd},
459    {{{50, 150}},
460     {{1, 0}},
461     {{50, 1}},
462     {{1, 1}},
463     HloOpcode::kAdd,
464     HloOpcode::kAdd},
465    {{{150, 50}},
466     {{1, 0}},
467     {{150, 1}},
468     {{150, 1}},
469     HloOpcode::kAdd,
470     HloOpcode::kAdd},
471    {{{150, 50}},
472     {{1, 0}},
473     {{150, 1}},
474     {{1, 50}},
475     HloOpcode::kAdd,
476     HloOpcode::kAdd},
477    {{{150, 50}},
478     {{1, 0}},
479     {{150, 1}},
480     {{1, 1}},
481     HloOpcode::kAdd,
482     HloOpcode::kAdd}};
483
// Value-parameterized fixture: inherits the BroadcastSimpleTest helpers and
// receives one R2ImplicitBroadcastSpec per test instance through GetParam().
class BroadcastR2ImplicitTest
    : public BroadcastSimpleTest,
      public ::testing::WithParamInterface<R2ImplicitBroadcastSpec> {};
487
488// Test r2 op1 r2_implicit_1 op2 r2_implicit_2
489// where R2 is a rank-2 operand, and r2_implicit_2 are two
490// rank-2 operands with degenerate dimensions:
491XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
492  const R2ImplicitBroadcastSpec& spec = GetParam();
493
494  ComputationBuilder builder(client_, TestName());
495
496  // Operands with degenerate dimensions require implicit broadcasting:
497  Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2;
498  Array2D<float> r2_array(spec.output_bounds[0], spec.output_bounds[1]);
499  Array2D<float> r2_implicit_array1(spec.input_bounds1[0],
500                                    spec.input_bounds1[1]);
501  Array2D<float> r2_implicit_array2(spec.input_bounds2[0],
502                                    spec.input_bounds2[1]);
503
504  std::unique_ptr<GlobalData> r2_global_data =
505      MakeR2Data(spec.output_bounds, spec.minor2major_layout, &r2_shape,
506                 &r2_array, 1.0, 2.5, 56789);
507  std::unique_ptr<GlobalData> r2_implicit_global_data1 =
508      MakeR2Data(spec.input_bounds1, spec.minor2major_layout,
509                 &r2_implicit_shape1, &r2_implicit_array1, 1.0, 0.2, 56789);
510  std::unique_ptr<GlobalData> r2_implicit_global_data2 =
511      MakeR2Data(spec.input_bounds2, spec.minor2major_layout,
512                 &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789);
513
514  auto r2_implicit_parameter1 =
515      builder.Parameter(0, r2_implicit_shape1, "input0");
516  auto r2_parameter = builder.Parameter(1, r2_shape, "input1");
517  auto r2_implicit_parameter2 =
518      builder.Parameter(2, r2_implicit_shape2, "input2");
519
520  ComputationDataHandle op1 =
521      BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder);
522  ComputationDataHandle op2 =
523      BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
524
525  Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]);
526
527  expected_array.Each([&](int64 i, int64 j, float* v) {
528    float v1 = r2_implicit_array1(i % spec.input_bounds1[0],
529                                  j % spec.input_bounds1[1]);
530    float v2 = r2_array(i, j);
531    float v3 = r2_implicit_array2(i % spec.input_bounds2[0],
532                                  j % spec.input_bounds2[1]);
533    float tmp = ApplyOpToFloats(spec.op1, v1, v2);
534    *v = ApplyOpToFloats(spec.op2, tmp, v3);
535  });
536
537  auto expected = Literal::CreateR2FromArray2D(expected_array);
538  ComputeAndCompareLiteral(
539      &builder, *expected,
540      {r2_implicit_global_data1.get(), r2_global_data.get(),
541       r2_implicit_global_data2.get()},
542      ErrorSpec(1e-6, 1e-6));
543}
544
// Instantiates BroadcastR2ImplicitTest once for every entry of
// kR2ImplicitBroadcastTestCases.
INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
                        BroadcastR2ImplicitTest,
                        ::testing::ValuesIn(kR2ImplicitBroadcastTestCases));
548
549XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
550  ComputationBuilder b(client_, TestName());
551  auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}}));
552  auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
553  b.Add(r2, r1);
554
555  auto expected = Literal::CreateR2<float>({{2, 4}, {4, 6}});
556
557  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
558}
559
560XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
561  ComputationBuilder b(client_, TestName());
562  auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1}, {2}}));
563  auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
564  b.Add(r2, r1);
565
566  auto expected = Literal::CreateR2<float>({{2, 3}, {5, 6}});
567
568  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
569}
570
571XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
572  ComputationBuilder b(client_, TestName());
573  auto r1 = b.ConstantR1<float>({10, 20});
574  auto r3 = b.ConstantLiteral(
575      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
576  b.Add(r3, r1, {0});
577
578  auto expected =
579      Literal::CreateR3<float>({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
580
581  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
582}
583
584XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
585  ComputationBuilder b(client_, TestName());
586  auto r1 = b.ConstantR1<float>({10, 20});
587  auto r3 = b.ConstantLiteral(
588      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
589  b.Add(r1, r3, {1});
590
591  auto expected =
592      Literal::CreateR3<float>({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
593
594  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
595}
596
597XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
598  ComputationBuilder b(client_, TestName());
599  auto r1 = b.ConstantR1<float>({10, 20});
600  auto r3 = b.ConstantLiteral(
601      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
602  b.Add(r1, r3, {2});
603
604  auto expected =
605      Literal::CreateR3<float>({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
606
607  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
608}
609
610XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
611  ComputationBuilder b(client_, TestName());
612  auto r1_0 = b.ConstantR1<float>({1000, 2000});
613  auto r1_1 = b.ConstantR1<float>({100, 200});
614  auto r1_2 = b.ConstantR1<float>({10, 20});
615  auto r3 = b.ConstantLiteral(
616      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
617  for (int i = 0; i < 3; ++i) {
618    r3 = b.Add(r1_0, r3, {0});
619    r3 = b.Add(r3, r1_1, {1});
620    r3 = b.Add(r1_2, r3, {2});
621  }
622  r3 = b.Mul(r3, b.ConstantR0<float>(-2));
623
624  auto expected = Literal::CreateR3<float>(
625      {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
626       {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
627
628  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
629}
630
631XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
632  ComputationBuilder b(client_, TestName());
633  auto r1_0 = b.ConstantR1<float>({1000, 2000});
634  auto r1_1 = b.ConstantR1<float>({100, 200});
635  auto r1_2 = b.ConstantR1<float>({10, 20});
636  auto r0 = b.ConstantR0<float>(3);
637  auto r3 = b.Broadcast(r0, {2, 2, 2});
638  for (int i = 0; i < 3; ++i) {
639    r3 = b.Add(r1_0, r3, {0});
640    r3 = b.Add(r3, r1_1, {1});
641    r3 = b.Add(r1_2, r3, {2});
642  }
643  r3 = b.Mul(r3, b.ConstantR0<float>(-1));
644
645  auto expected = Literal::CreateR3<float>(
646      {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
647       {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
648
649  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
650}
651
652XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
653  // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2])
654  // results in a shape incompatible with the lhs [2, 3, 1].
655  ComputationBuilder b(client_, TestName());
656
657  b.Add(b.ConstantR2<float>({{1.0, 5.0}, {1.0, 5.0}}),
658        b.ConstantLiteral(*Literal::CreateR3<float>(
659            {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
660        /*broadcast_dimensions=*/{1, 2});
661
662  auto result_status = Execute(&b, {});
663  EXPECT_FALSE(result_status.ok());
664  EXPECT_THAT(result_status.status().error_message(),
665              HasSubstr("broadcast dimension 0 mismatch"));
666}
667
668XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
669  // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
670  ComputationBuilder b(client_, TestName());
671
672  b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
673        b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
674
675  auto result_status = Execute(&b, {});
676  EXPECT_FALSE(result_status.ok());
677  EXPECT_THAT(result_status.status().error_message(),
678              HasSubstr("binary op BINOP_ADD with incompatible shapes"));
679}
680
681XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
682  // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
683  ComputationBuilder b(client_, TestName());
684
685  b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
686        b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
687
688  auto result_status = Execute(&b, {});
689  EXPECT_FALSE(result_status.ok());
690  EXPECT_THAT(result_status.status().error_message(),
691              HasSubstr("binary op BINOP_ADD with incompatible shapes"));
692}
693
694}  // namespace
695}  // namespace xla
696