1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
32
33namespace xla {
34namespace {
35
36using ConcatTest = ClientLibraryTestBase;
37using ::testing::HasSubstr;
38
39// Concatenate expects at least one argument.
40XLA_TEST_F(ConcatTest, Concat_Nothing) {
41  ComputationBuilder builder(client_, TestName());
42  auto concatenated = builder.ConcatInDim({}, 0);
43  StatusOr<Computation> computation_status = builder.Build();
44  ASSERT_FALSE(computation_status.ok());
45  EXPECT_THAT(computation_status.status().ToString(),
46              HasSubstr("Concatenate expects at least one argument"));
47}
48
49// Concatenate with one argument works.
50XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
51  ComputationBuilder builder(client_, TestName());
52  auto a = builder.ConstantR1<float>({42.0, 64.0});
53  auto concatenated = builder.ConcatInDim({a}, 0);
54
55  std::vector<float> expected = {42, 64};
56  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
57}
58
59XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
60  ComputationBuilder builder(client_, TestName());
61  auto a = builder.ConstantR1<float>({});
62  auto concatenated = builder.ConcatInDim({a}, 0);
63
64  std::vector<float> expected = {};
65  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
66}
67
68// Show that we can't concatenate R0 with R0 because we can't name the dimension
69// to concatenate on.
70XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
71  ComputationBuilder builder(client_, TestName());
72  auto a = builder.ConstantR0<float>(42.0);
73  auto b = builder.ConstantR0<float>(64.0);
74  auto concatenated = builder.ConcatInDim({a, b}, 0);
75  StatusOr<Computation> computation_status = builder.Build();
76  ASSERT_FALSE(computation_status.ok());
77  EXPECT_THAT(computation_status.status().ToString(),
78              HasSubstr("dimension to concatenate along out of bounds: 0"));
79}
80
81XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
82  ComputationBuilder builder(client_, TestName());
83  auto a = builder.ConstantR1<float>({});
84  auto b = builder.ConstantR1<float>({});
85  auto concatenated = builder.ConcatInDim({a, b}, 0);
86
87  std::vector<float> expected = {};
88  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
89}
90
91XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
92  ComputationBuilder builder(client_, TestName());
93  auto a = builder.ConstantR1<float>({});
94  auto b = builder.ConstantR1<float>({256.0});
95  auto concatenated = builder.ConcatInDim({a, b}, 0);
96
97  std::vector<float> expected = {256};
98  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
99}
100
101XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
102  ComputationBuilder builder(client_, TestName());
103  auto a = builder.ConstantR1<float>({42.0, 64.0});
104  auto b = builder.ConstantR1<float>({});
105  auto concatenated = builder.ConcatInDim({a, b}, 0);
106
107  std::vector<float> expected = {42, 64};
108  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
109}
110
111XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
112  ComputationBuilder builder(client_, TestName());
113  auto a = builder.ConstantR1<float>({42.0, 64.0});
114  auto b = builder.ConstantR1<float>({256.0});
115  auto concatenated = builder.ConcatInDim({a, b}, 0);
116
117  std::vector<float> expected = {42, 64, 256};
118  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
119}
120
121XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
122  std::vector<float> lhs(253);
123  std::vector<float> rhs(7);
124  std::vector<float> expected(253 + 7);
125  for (int i = 0; i < 253; ++i) {
126    expected[i] = lhs[i] = i + 1;
127  }
128  for (int i = 0; i < 7; ++i) {
129    expected[253 + i] = rhs[i] = 253 + i + 1;
130  }
131
132  ComputationBuilder builder(client_, TestName());
133  auto a = builder.ConstantR1<float>(lhs);
134  auto b = builder.ConstantR1<float>(rhs);
135  auto concatenated = builder.ConcatInDim({a, b}, 0);
136
137  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
138}
139
140XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
141  for (int dim : {0, 1}) {
142    ComputationBuilder builder(client_, TestName());
143    auto a = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
144    auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
145    auto concatenated = builder.ConcatInDim({a, b}, dim);
146
147    ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {},
148                               ErrorSpec(0.0001));
149  }
150}
151
152XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
153  ComputationBuilder builder(client_, TestName());
154  auto a_array = CreatePatternedMatrix(1, 1);
155  auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
156  auto a = builder.ConstantR2FromArray2D(*a_array);
157  auto b = builder.ConstantR2FromArray2D(*b_array);
158  auto concatenated = builder.ConcatInDim({a, b}, 0);
159
160  Array2D<float> expected({
161      {0}, {64},
162  });
163  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
164}
165
166XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
167  ComputationBuilder builder(client_, TestName());
168  auto a_array = CreatePatternedMatrix(1, 1);
169  auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
170  auto a = builder.ConstantR2FromArray2D(*a_array);
171  auto b = builder.ConstantR2FromArray2D(*b_array);
172  auto concatenated = builder.ConcatInDim({a, b}, 1);
173
174  Array2D<float> expected({
175      {0, 64},
176  });
177  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
178}
179
180XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
181  ComputationBuilder builder(client_, TestName());
182  auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
183  auto a = builder.ConstantR2FromArray2D(Array2D<float>(2, 0));
184  auto b = builder.ConstantR2FromArray2D(*b_array);
185  auto concatenated = builder.ConcatInDim({a, b}, 1);
186
187  ComputeAndCompareR2<float>(&builder, *b_array, {}, ErrorSpec(0.0001));
188}
189
190XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
191  ComputationBuilder builder(client_, TestName());
192  auto a_array = CreatePatternedMatrix(2, 3);
193  auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
194  auto a = builder.ConstantR2FromArray2D(*a_array);
195  auto b = builder.ConstantR2FromArray2D(*b_array);
196  auto concatenated = builder.ConcatInDim({a, b}, 1);
197
198  Array2D<float> expected({
199      {0, 1, 2, 64, 65, 66, 67, 68},
200      {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068},
201  });
202  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
203}
204
205XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
206  ComputationBuilder builder(client_, TestName());
207  auto a_array = CreatePatternedMatrix(3, 2);
208  auto a = builder.ConstantR2FromArray2D(*a_array);
209  auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 2));
210  auto concatenated = builder.ConcatInDim({a, b}, 0);
211
212  ComputeAndCompareR2<float>(&builder, *a_array, {}, ErrorSpec(0.0001));
213}
214
215XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
216  ComputationBuilder builder(client_, TestName());
217  auto a_array = CreatePatternedMatrix(3, 2);
218  auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
219  auto a = builder.ConstantR2FromArray2D(*a_array);
220  auto b = builder.ConstantR2FromArray2D(*b_array);
221  auto concatenated = builder.ConcatInDim({a, b}, 0);
222
223  Array2D<float> expected({
224      {0, 1},
225      {1000, 1001},
226      {2000, 2001},
227      {64, 65},
228      {1064, 1065},
229      {2064, 2065},
230      {3064, 3065},
231      {4064, 4065},
232  });
233  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
234}
235
236XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
237  ComputationBuilder builder(client_, TestName());
238  auto a = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 2));
239  auto b = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 1));
240  auto concatenated = builder.ConcatInDim({a, b}, 2);
241  ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 3), {},
242                             ErrorSpec(0.0001));
243}
244
245XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
246  ComputationBuilder builder(client_, TestName());
247  Array3D<float> a_array({
248      // 3x1x2
249      {{0, 1}},
250      {{2, 3}},
251      {{4, 5}},
252  });
253  Array3D<float> b_array({
254      // 3x1x1
255      {{6}},
256      {{7}},
257      {{8}},
258  });
259  auto a = builder.ConstantR3FromArray3D(a_array);
260  auto b = builder.ConstantR3FromArray3D(b_array);
261  auto concatenated = builder.ConcatInDim({a, b}, 2);
262
263  Array3D<float> expected({
264      {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}},
265  });
266  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
267}
268
269XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
270  ComputationBuilder builder(client_, TestName());
271  auto a = builder.ConstantR1<float>({42.0});
272  auto b = builder.ConstantR1<float>({64.0});
273  auto c = builder.ConstantR1<float>({256.0});
274  auto concatenated = builder.ConcatInDim({a, b, c}, 0);
275
276  std::vector<float> expected = {42, 64, 256};
277  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
278}
279
280XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
281  ComputationBuilder builder(client_, TestName());
282  Array3D<float> a_array({
283      // 3x1x2
284      {{0, 1}},
285      {{4, 5}},
286      {{8, 9}},
287  });
288  Array3D<float> b_array({
289      // 3x1x1
290      {{2}},
291      {{6}},
292      {{10}},
293  });
294  Array3D<float> c_array({
295      // 3x1x1
296      {{3}},
297      {{7}},
298      {{11}},
299  });
300  auto a = builder.ConstantR3FromArray3D(a_array);
301  auto b = builder.ConstantR3FromArray3D(b_array);
302  auto c = builder.ConstantR3FromArray3D(c_array);
303  auto concatenated = builder.ConcatInDim({a, b, c}, 2);
304
305  Array3D<float> expected({
306      {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}},
307  });
308  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
309}
310
311XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
312  ComputationBuilder builder(client_, TestName());
313  auto a = builder.ConstantR1<float>({42.0});
314  auto b = builder.ConstantR1<float>({64.0});
315  auto c = builder.ConstantR1<float>({256.0});
316  // concatenated = (a concat b) concat c
317  auto concatenated =
318      builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0);
319
320  std::vector<float> expected = {42, 64, 256};
321  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
322}
323
324XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
325  ComputationBuilder builder(client_, TestName());
326  auto a = builder.ConstantR1<float>({42.0});
327  auto b = builder.ConstantR1<float>({64.0});
328  auto c = builder.ConstantR1<float>({256.0});
329  // concatenated = a concat (b concat c)
330  auto concatenated =
331      builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0);
332
333  std::vector<float> expected = {42, 64, 256};
334  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
335}
336
337XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
338  Array2D<float> lhs(1, 1024);
339  Array2D<float> rhs(1, 1024);
340  for (int i = 0; i < 1024; ++i) {
341    lhs(0, i) = i;
342    rhs(0, i) = i + 1024;
343  }
344
345  ComputationBuilder builder(client_, TestName());
346  auto a = builder.ConstantR2FromArray2D<float>(lhs);
347  auto b = builder.ConstantR2FromArray2D<float>(rhs);
348  builder.ConcatInDim({a, b}, 0);
349
350  Array2D<float> expected(2, 1024);
351  for (int i = 0; i < 1024; ++i) {
352    expected(0, i) = i;
353    expected(1, i) = i + 1024;
354  }
355  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
356}
357
358XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
359  Array2D<float> lhs(1, 1024);
360  Array2D<float> rhs(1, 1024);
361  for (int i = 0; i < 1024; ++i) {
362    lhs(0, i) = i;
363    rhs(0, i) = i + 1024;
364  }
365
366  ComputationBuilder builder(client_, TestName());
367  auto a = builder.ConstantR2FromArray2D<float>(lhs);
368  auto b = builder.ConstantR2FromArray2D<float>(rhs);
369  builder.ConcatInDim({a, b}, 1);
370
371  Array2D<float> expected(1, 2048);
372  for (int i = 0; i < 1024; ++i) {
373    expected(0, i) = i;
374    expected(0, i + 1024) = i + 1024;
375  }
376  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
377}
378
379XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
380  Array2D<float> lhs(64, 64);
381  Array2D<float> rhs(64, 2);
382  for (int i0 = 0; i0 < 64; ++i0) {
383    for (int i1 = 0; i1 < 64; ++i1) {
384      lhs(i0, i1) = (i0 << 10) | i1;
385    }
386    for (int i1 = 0; i1 < 2; ++i1) {
387      rhs(i0, i1) = (i0 << 10) | (i1 + 64);
388    }
389  }
390
391  ComputationBuilder builder(client_, TestName());
392  auto a = builder.ConstantR2FromArray2D<float>(lhs);
393  auto b = builder.ConstantR2FromArray2D<float>(rhs);
394  builder.ConcatInDim({a, b}, 1);
395
396  Array2D<float> expected(64, 66);
397  for (int i0 = 0; i0 < 64; ++i0) {
398    for (int i1 = 0; i1 < 66; ++i1) {
399      expected(i0, i1) = (i0 << 10) | i1;
400    }
401  }
402  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
403}
404
// Show that we can't concatenate with an opaque operand.
406XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
407  ComputationBuilder builder(client_, TestName());
408  auto opaque_shape = ShapeUtil::MakeOpaqueShape();
409  auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
410  auto x = builder.Parameter(0, r1f32, "x");
411  auto y = builder.Parameter(1, opaque_shape, "y");
412  auto concatenated = builder.ConcatInDim({x, y}, 0);
413  StatusOr<Computation> computation_status = builder.Build();
414  ASSERT_FALSE(computation_status.ok());
415  EXPECT_THAT(
416      computation_status.status().ToString(),
417      HasSubstr("Expected non-opaque argument for operand of concatenation"));
418}
419
420XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
421  ComputationBuilder builder(client_, TestName());
422  auto p0 = builder.ConstantR1<bool>({true});
423  auto p1 = builder.ConstantR1<bool>({false});
424  auto p2 = builder.ConstantR1<bool>({true});
425  auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0);
426
427  bool expected[] = {true, false, true};
428  ComputeAndCompareR1<bool>(&builder, expected, {});
429}
430
431XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
432  ComputationBuilder builder(client_, TestName());
433  auto a0 = builder.ConstantR1<int32>({1});
434  auto a1 = builder.ConstantR1<int32>({2, 3});
435  auto a2 = builder.ConstantR1<int32>({4, 5, 6});
436  auto a3 = builder.ConstantR1<int32>({7, 8, 9, 10});
437  auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0);
438
439  std::vector<int32> expected(10);
440  std::iota(expected.begin(), expected.end(), 1);
441  ComputeAndCompareR1<int32>(&builder, expected, {});
442}
443
444XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
445  ComputationBuilder builder(client_, TestName());
446
447  Array3D<float> arr0(9, 17, 1);
448  arr0.Fill(1);
449
450  Array3D<float> arr1(9, 17, 256);
451  arr1.Fill(2);
452
453  Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
454  for (int64 i = 0; i < expected.n1(); ++i) {
455    for (int64 j = 0; j < expected.n2(); ++j) {
456      int64 kk = 0;
457      for (const Array3D<float>& arr : {arr0, arr1}) {
458        for (int64 k = 0; k < arr.n3(); ++k, ++kk) {
459          expected(i, j, kk) = arr(i, j, k);
460        }
461      }
462    }
463  }
464
465  ComputationDataHandle h0;
466  auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
467                                     &builder, &h0);
468  ComputationDataHandle h1;
469  auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
470                                     &builder, &h1);
471
472  auto concatenated = builder.ConcatInDim({h0, h1}, 2);
473
474  ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
475}
476
477// Describes a binary rank-2 concatenation test.
478struct R2BinarySpec {
479  int64 lhs_dim0;
480  int64 lhs_dim1;
481  int64 rhs_dim0;
482  int64 rhs_dim1;
483  int64 concat_dimension;
484};
485
486// TEST_P harness for binary rank-2 concatenation.
487class ConcatR2BinaryTest : public ClientLibraryTestBase,
488                           public ::testing::WithParamInterface<R2BinarySpec> {
489};
490
491TEST_P(ConcatR2BinaryTest, DoIt) {
492  const R2BinarySpec& spec = GetParam();
493  Array2D<int32> lhs(spec.lhs_dim0, spec.lhs_dim1);
494  lhs.FillUnique();
495  Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1);
496  rhs.FillUnique(1000);
497
498  ComputationBuilder builder(client_, TestName());
499  auto a0 = builder.ConstantR2FromArray2D<int32>(lhs);
500  auto a1 = builder.ConstantR2FromArray2D<int32>(rhs);
501  builder.ConcatInDim({a0, a1}, spec.concat_dimension);
502
503  std::unique_ptr<Array2D<int32>> expected =
504      ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension);
505  ComputeAndCompareR2<int32>(&builder, *expected, {});
506}
507
// Regression test for b/31944287. x*y is used (at the same index) by all
// operands of the concat. We should emit x*y in each of the three incoming
// basic blocks of the concat because these basic blocks are not
// control-equivalent.
//
//         x*y
//      /   |   \   (fan-out)
//   add1  add2  add3
//      \   |   /
//       concat
517XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
518  auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
519  auto x_literal = Literal::CreateR0<float>(2.f);
520  auto y_literal = Literal::CreateR0<float>(3.f);
521  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
522  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
523
524  ComputationBuilder builder(client_, TestName());
525  auto x = builder.Parameter(0, f32_scalar, "x");
526  auto y = builder.Parameter(1, f32_scalar, "y");
527  auto mul = builder.Mul(x, y);
528  auto add1 = builder.Add(mul, builder.ConstantR1<float>({1.f, 2.f}));
529  auto add2 = builder.Add(mul, builder.ConstantR1<float>({3.f, 4.f}));
530  auto add3 = builder.Add(mul, builder.ConstantR1<float>({5.f, 6.f}));
531  builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0);
532
533  ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
534                             {x_data.get(), y_data.get()}, ErrorSpec(1e-4));
535}
536
// Test that the HLO optimization that replaces a concat of a broadcasted
// scalar produces the correct result in rank 1.
539XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
540  auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
541  auto x_literal = Literal::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
542  auto y_literal = Literal::CreateR0<float>(1.5f);
543  auto z_literal = Literal::CreateR0<float>(5.5f);
544  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
545  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
546  auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
547
548  ComputationBuilder builder(client_, TestName());
549  auto x = builder.Parameter(0, x_literal->shape(), "x");
550  auto y = builder.Parameter(1, f32_scalar, "y");
551  auto z = builder.Parameter(2, f32_scalar, "z");
552  auto bcast = builder.Broadcast(y, {5});
553  auto bcast2 = builder.Broadcast(z, {3});
554  auto concat = builder.ConcatInDim({bcast, x}, /*dimension=*/0);
555  builder.ConcatInDim({concat, bcast2}, /*dimension=*/0);
556
557  ComputeAndCompareR1<float>(
558      &builder,
559      {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f, 3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f},
560      {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4));
561}
562
// Test that the HLO optimization that replaces a concat of a broadcasted
// scalar produces the correct result in rank 3 with both high and low padding
// in different dimensions.
566XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
567  auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
568  Array3D<float> x3d(3, 5, 7, 3.14f);
569  auto x_literal = Literal::CreateR3FromArray3D<float>(x3d);
570  auto y_literal = Literal::CreateR0<float>(1.5f);
571  auto z_literal = Literal::CreateR0<float>(5.5f);
572  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
573  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
574  auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
575
576  ComputationBuilder builder(client_, TestName());
577  auto x = builder.Parameter(0, x_literal->shape(), "x");
578  auto y = builder.Parameter(1, f32_scalar, "y");
579  auto z = builder.Parameter(2, f32_scalar, "y");
580  auto y_bcast = builder.Broadcast(y, {1, 5, 7});
581  auto z_bcast = builder.Broadcast(z, {4, 1, 7});
582  auto concat = builder.ConcatInDim({y_bcast, x}, /*dimension=*/0);
583  builder.ConcatInDim({concat, z_bcast}, /*dimension=*/1);
584  Array3D<float> y_bcast3d(1, 5, 7, 1.5f);
585  Array3D<float> z_bcast3d(4, 1, 7, 5.5f);
586  auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0);
587  auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1);
588
589  ComputeAndCompareR3<float>(&builder, *concat1,
590                             {x_data.get(), y_data.get(), z_data.get()},
591                             ErrorSpec(1e-4));
592}
593
594INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest,
595                        ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0},
596                                          R2BinarySpec{1, 1, 1, 1, 1},
597                                          R2BinarySpec{4, 3, 4, 3, 0},
598                                          R2BinarySpec{4, 3, 4, 3, 1},
599                                          R2BinarySpec{7, 128, 1, 128, 0},
600                                          R2BinarySpec{8, 127, 8, 1, 1}));
601
602}  // namespace
603}  // namespace xla
604