/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

using ConcatTest = ClientLibraryTestBase;
using ::testing::HasSubstr;

// Concatenate expects at least one argument.
40XLA_TEST_F(ConcatTest, Concat_Nothing) { 41 ComputationBuilder builder(client_, TestName()); 42 auto concatenated = builder.ConcatInDim({}, 0); 43 StatusOr<Computation> computation_status = builder.Build(); 44 ASSERT_FALSE(computation_status.ok()); 45 EXPECT_THAT(computation_status.status().ToString(), 46 HasSubstr("Concatenate expects at least one argument")); 47} 48 49// Concatenate with one argument works. 50XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { 51 ComputationBuilder builder(client_, TestName()); 52 auto a = builder.ConstantR1<float>({42.0, 64.0}); 53 auto concatenated = builder.ConcatInDim({a}, 0); 54 55 std::vector<float> expected = {42, 64}; 56 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 57} 58 59XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { 60 ComputationBuilder builder(client_, TestName()); 61 auto a = builder.ConstantR1<float>({}); 62 auto concatenated = builder.ConcatInDim({a}, 0); 63 64 std::vector<float> expected = {}; 65 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 66} 67 68// Show that we can't concatenate R0 with R0 because we can't name the dimension 69// to concatenate on. 
70XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { 71 ComputationBuilder builder(client_, TestName()); 72 auto a = builder.ConstantR0<float>(42.0); 73 auto b = builder.ConstantR0<float>(64.0); 74 auto concatenated = builder.ConcatInDim({a, b}, 0); 75 StatusOr<Computation> computation_status = builder.Build(); 76 ASSERT_FALSE(computation_status.ok()); 77 EXPECT_THAT(computation_status.status().ToString(), 78 HasSubstr("dimension to concatenate along out of bounds: 0")); 79} 80 81XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { 82 ComputationBuilder builder(client_, TestName()); 83 auto a = builder.ConstantR1<float>({}); 84 auto b = builder.ConstantR1<float>({}); 85 auto concatenated = builder.ConcatInDim({a, b}, 0); 86 87 std::vector<float> expected = {}; 88 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 89} 90 91XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { 92 ComputationBuilder builder(client_, TestName()); 93 auto a = builder.ConstantR1<float>({}); 94 auto b = builder.ConstantR1<float>({256.0}); 95 auto concatenated = builder.ConcatInDim({a, b}, 0); 96 97 std::vector<float> expected = {256}; 98 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 99} 100 101XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { 102 ComputationBuilder builder(client_, TestName()); 103 auto a = builder.ConstantR1<float>({42.0, 64.0}); 104 auto b = builder.ConstantR1<float>({}); 105 auto concatenated = builder.ConcatInDim({a, b}, 0); 106 107 std::vector<float> expected = {42, 64}; 108 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 109} 110 111XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) { 112 ComputationBuilder builder(client_, TestName()); 113 auto a = builder.ConstantR1<float>({42.0, 64.0}); 114 auto b = builder.ConstantR1<float>({256.0}); 115 auto concatenated = builder.ConcatInDim({a, b}, 0); 116 117 std::vector<float> expected = {42, 64, 256}; 118 ComputeAndCompareR1<float>(&builder, expected, {}, 
ErrorSpec(0.0001)); 119} 120 121XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { 122 std::vector<float> lhs(253); 123 std::vector<float> rhs(7); 124 std::vector<float> expected(253 + 7); 125 for (int i = 0; i < 253; ++i) { 126 expected[i] = lhs[i] = i + 1; 127 } 128 for (int i = 0; i < 7; ++i) { 129 expected[253 + i] = rhs[i] = 253 + i + 1; 130 } 131 132 ComputationBuilder builder(client_, TestName()); 133 auto a = builder.ConstantR1<float>(lhs); 134 auto b = builder.ConstantR1<float>(rhs); 135 auto concatenated = builder.ConcatInDim({a, b}, 0); 136 137 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 138} 139 140XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { 141 for (int dim : {0, 1}) { 142 ComputationBuilder builder(client_, TestName()); 143 auto a = builder.ConstantR2FromArray2D(Array2D<float>(0, 0)); 144 auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 0)); 145 auto concatenated = builder.ConcatInDim({a, b}, dim); 146 147 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, 148 ErrorSpec(0.0001)); 149 } 150} 151 152XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) { 153 ComputationBuilder builder(client_, TestName()); 154 auto a_array = CreatePatternedMatrix(1, 1); 155 auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); 156 auto a = builder.ConstantR2FromArray2D(*a_array); 157 auto b = builder.ConstantR2FromArray2D(*b_array); 158 auto concatenated = builder.ConcatInDim({a, b}, 0); 159 160 Array2D<float> expected({ 161 {0}, {64}, 162 }); 163 ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001)); 164} 165 166XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { 167 ComputationBuilder builder(client_, TestName()); 168 auto a_array = CreatePatternedMatrix(1, 1); 169 auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); 170 auto a = builder.ConstantR2FromArray2D(*a_array); 171 auto b = builder.ConstantR2FromArray2D(*b_array); 172 auto concatenated = builder.ConcatInDim({a, b}, 1); 173 174 
Array2D<float> expected({ 175 {0, 64}, 176 }); 177 ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001)); 178} 179 180XLA_TEST_F(ConcatTest, Concat2x0With2x5) { 181 ComputationBuilder builder(client_, TestName()); 182 auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); 183 auto a = builder.ConstantR2FromArray2D(Array2D<float>(2, 0)); 184 auto b = builder.ConstantR2FromArray2D(*b_array); 185 auto concatenated = builder.ConcatInDim({a, b}, 1); 186 187 ComputeAndCompareR2<float>(&builder, *b_array, {}, ErrorSpec(0.0001)); 188} 189 190XLA_TEST_F(ConcatTest, Concat2x3With2x5) { 191 ComputationBuilder builder(client_, TestName()); 192 auto a_array = CreatePatternedMatrix(2, 3); 193 auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); 194 auto a = builder.ConstantR2FromArray2D(*a_array); 195 auto b = builder.ConstantR2FromArray2D(*b_array); 196 auto concatenated = builder.ConcatInDim({a, b}, 1); 197 198 Array2D<float> expected({ 199 {0, 1, 2, 64, 65, 66, 67, 68}, 200 {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068}, 201 }); 202 ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001)); 203} 204 205XLA_TEST_F(ConcatTest, Concat3x2With0x2) { 206 ComputationBuilder builder(client_, TestName()); 207 auto a_array = CreatePatternedMatrix(3, 2); 208 auto a = builder.ConstantR2FromArray2D(*a_array); 209 auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 2)); 210 auto concatenated = builder.ConcatInDim({a, b}, 0); 211 212 ComputeAndCompareR2<float>(&builder, *a_array, {}, ErrorSpec(0.0001)); 213} 214 215XLA_TEST_F(ConcatTest, Concat3x2With5x2) { 216 ComputationBuilder builder(client_, TestName()); 217 auto a_array = CreatePatternedMatrix(3, 2); 218 auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0); 219 auto a = builder.ConstantR2FromArray2D(*a_array); 220 auto b = builder.ConstantR2FromArray2D(*b_array); 221 auto concatenated = builder.ConcatInDim({a, b}, 0); 222 223 Array2D<float> expected({ 224 {0, 1}, 225 {1000, 
1001}, 226 {2000, 2001}, 227 {64, 65}, 228 {1064, 1065}, 229 {2064, 2065}, 230 {3064, 3065}, 231 {4064, 4065}, 232 }); 233 ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001)); 234} 235 236XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) { 237 ComputationBuilder builder(client_, TestName()); 238 auto a = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 2)); 239 auto b = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 1)); 240 auto concatenated = builder.ConcatInDim({a, b}, 2); 241 ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 3), {}, 242 ErrorSpec(0.0001)); 243} 244 245XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { 246 ComputationBuilder builder(client_, TestName()); 247 Array3D<float> a_array({ 248 // 3x1x2 249 {{0, 1}}, 250 {{2, 3}}, 251 {{4, 5}}, 252 }); 253 Array3D<float> b_array({ 254 // 3x1x1 255 {{6}}, 256 {{7}}, 257 {{8}}, 258 }); 259 auto a = builder.ConstantR3FromArray3D(a_array); 260 auto b = builder.ConstantR3FromArray3D(b_array); 261 auto concatenated = builder.ConcatInDim({a, b}, 2); 262 263 Array3D<float> expected({ 264 {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}}, 265 }); 266 ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001)); 267} 268 269XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) { 270 ComputationBuilder builder(client_, TestName()); 271 auto a = builder.ConstantR1<float>({42.0}); 272 auto b = builder.ConstantR1<float>({64.0}); 273 auto c = builder.ConstantR1<float>({256.0}); 274 auto concatenated = builder.ConcatInDim({a, b, c}, 0); 275 276 std::vector<float> expected = {42, 64, 256}; 277 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 278} 279 280XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { 281 ComputationBuilder builder(client_, TestName()); 282 Array3D<float> a_array({ 283 // 3x1x2 284 {{0, 1}}, 285 {{4, 5}}, 286 {{8, 9}}, 287 }); 288 Array3D<float> b_array({ 289 // 3x1x1 290 {{2}}, 291 {{6}}, 292 {{10}}, 293 }); 294 Array3D<float> c_array({ 295 // 3x1x1 296 {{3}}, 
297 {{7}}, 298 {{11}}, 299 }); 300 auto a = builder.ConstantR3FromArray3D(a_array); 301 auto b = builder.ConstantR3FromArray3D(b_array); 302 auto c = builder.ConstantR3FromArray3D(c_array); 303 auto concatenated = builder.ConcatInDim({a, b, c}, 2); 304 305 Array3D<float> expected({ 306 {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}}, 307 }); 308 ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001)); 309} 310 311XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { 312 ComputationBuilder builder(client_, TestName()); 313 auto a = builder.ConstantR1<float>({42.0}); 314 auto b = builder.ConstantR1<float>({64.0}); 315 auto c = builder.ConstantR1<float>({256.0}); 316 // concatenated = (a concat b) concat c 317 auto concatenated = 318 builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); 319 320 std::vector<float> expected = {42, 64, 256}; 321 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 322} 323 324XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) { 325 ComputationBuilder builder(client_, TestName()); 326 auto a = builder.ConstantR1<float>({42.0}); 327 auto b = builder.ConstantR1<float>({64.0}); 328 auto c = builder.ConstantR1<float>({256.0}); 329 // concatenated = a concat (b concat c) 330 auto concatenated = 331 builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); 332 333 std::vector<float> expected = {42, 64, 256}; 334 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 335} 336 337XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) { 338 Array2D<float> lhs(1, 1024); 339 Array2D<float> rhs(1, 1024); 340 for (int i = 0; i < 1024; ++i) { 341 lhs(0, i) = i; 342 rhs(0, i) = i + 1024; 343 } 344 345 ComputationBuilder builder(client_, TestName()); 346 auto a = builder.ConstantR2FromArray2D<float>(lhs); 347 auto b = builder.ConstantR2FromArray2D<float>(rhs); 348 builder.ConcatInDim({a, b}, 0); 349 350 Array2D<float> expected(2, 1024); 351 for (int i = 0; i < 1024; ++i) { 352 expected(0, i) 
= i; 353 expected(1, i) = i + 1024; 354 } 355 ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001)); 356} 357 358XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) { 359 Array2D<float> lhs(1, 1024); 360 Array2D<float> rhs(1, 1024); 361 for (int i = 0; i < 1024; ++i) { 362 lhs(0, i) = i; 363 rhs(0, i) = i + 1024; 364 } 365 366 ComputationBuilder builder(client_, TestName()); 367 auto a = builder.ConstantR2FromArray2D<float>(lhs); 368 auto b = builder.ConstantR2FromArray2D<float>(rhs); 369 builder.ConcatInDim({a, b}, 1); 370 371 Array2D<float> expected(1, 2048); 372 for (int i = 0; i < 1024; ++i) { 373 expected(0, i) = i; 374 expected(0, i + 1024) = i + 1024; 375 } 376 ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001)); 377} 378 379XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { 380 Array2D<float> lhs(64, 64); 381 Array2D<float> rhs(64, 2); 382 for (int i0 = 0; i0 < 64; ++i0) { 383 for (int i1 = 0; i1 < 64; ++i1) { 384 lhs(i0, i1) = (i0 << 10) | i1; 385 } 386 for (int i1 = 0; i1 < 2; ++i1) { 387 rhs(i0, i1) = (i0 << 10) | (i1 + 64); 388 } 389 } 390 391 ComputationBuilder builder(client_, TestName()); 392 auto a = builder.ConstantR2FromArray2D<float>(lhs); 393 auto b = builder.ConstantR2FromArray2D<float>(rhs); 394 builder.ConcatInDim({a, b}, 1); 395 396 Array2D<float> expected(64, 66); 397 for (int i0 = 0; i0 < 64; ++i0) { 398 for (int i1 = 0; i1 < 66; ++i1) { 399 expected(i0, i1) = (i0 << 10) | i1; 400 } 401 } 402 ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001)); 403} 404 405// Show that we can't concatenate with an opaques. 
406XLA_TEST_F(ConcatTest, CannotConcatOpaques) { 407 ComputationBuilder builder(client_, TestName()); 408 auto opaque_shape = ShapeUtil::MakeOpaqueShape(); 409 auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); 410 auto x = builder.Parameter(0, r1f32, "x"); 411 auto y = builder.Parameter(1, opaque_shape, "y"); 412 auto concatenated = builder.ConcatInDim({x, y}, 0); 413 StatusOr<Computation> computation_status = builder.Build(); 414 ASSERT_FALSE(computation_status.ok()); 415 EXPECT_THAT( 416 computation_status.status().ToString(), 417 HasSubstr("Expected non-opaque argument for operand of concatenation")); 418} 419 420XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { 421 ComputationBuilder builder(client_, TestName()); 422 auto p0 = builder.ConstantR1<bool>({true}); 423 auto p1 = builder.ConstantR1<bool>({false}); 424 auto p2 = builder.ConstantR1<bool>({true}); 425 auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0); 426 427 bool expected[] = {true, false, true}; 428 ComputeAndCompareR1<bool>(&builder, expected, {}); 429} 430 431XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { 432 ComputationBuilder builder(client_, TestName()); 433 auto a0 = builder.ConstantR1<int32>({1}); 434 auto a1 = builder.ConstantR1<int32>({2, 3}); 435 auto a2 = builder.ConstantR1<int32>({4, 5, 6}); 436 auto a3 = builder.ConstantR1<int32>({7, 8, 9, 10}); 437 auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0); 438 439 std::vector<int32> expected(10); 440 std::iota(expected.begin(), expected.end(), 1); 441 ComputeAndCompareR1<int32>(&builder, expected, {}); 442} 443 444XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { 445 ComputationBuilder builder(client_, TestName()); 446 447 Array3D<float> arr0(9, 17, 1); 448 arr0.Fill(1); 449 450 Array3D<float> arr1(9, 17, 256); 451 arr1.Fill(2); 452 453 Array3D<float> expected(9, 17, arr0.n3() + arr1.n3()); 454 for (int64 i = 0; i < expected.n1(); ++i) { 455 for (int64 j = 0; j < expected.n2(); ++j) { 456 int64 kk = 0; 457 for (const 
Array3D<float>& arr : {arr0, arr1}) { 458 for (int64 k = 0; k < arr.n3(); ++k, ++kk) { 459 expected(i, j, kk) = arr(i, j, k); 460 } 461 } 462 } 463 } 464 465 ComputationDataHandle h0; 466 auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0", 467 &builder, &h0); 468 ComputationDataHandle h1; 469 auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1", 470 &builder, &h1); 471 472 auto concatenated = builder.ConcatInDim({h0, h1}, 2); 473 474 ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()}); 475} 476 477// Describes a binary rank-2 concatenation test. 478struct R2BinarySpec { 479 int64 lhs_dim0; 480 int64 lhs_dim1; 481 int64 rhs_dim0; 482 int64 rhs_dim1; 483 int64 concat_dimension; 484}; 485 486// TEST_P harness for binary rank-2 concatenation. 487class ConcatR2BinaryTest : public ClientLibraryTestBase, 488 public ::testing::WithParamInterface<R2BinarySpec> { 489}; 490 491TEST_P(ConcatR2BinaryTest, DoIt) { 492 const R2BinarySpec& spec = GetParam(); 493 Array2D<int32> lhs(spec.lhs_dim0, spec.lhs_dim1); 494 lhs.FillUnique(); 495 Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1); 496 rhs.FillUnique(1000); 497 498 ComputationBuilder builder(client_, TestName()); 499 auto a0 = builder.ConstantR2FromArray2D<int32>(lhs); 500 auto a1 = builder.ConstantR2FromArray2D<int32>(rhs); 501 builder.ConcatInDim({a0, a1}, spec.concat_dimension); 502 503 std::unique_ptr<Array2D<int32>> expected = 504 ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension); 505 ComputeAndCompareR2<int32>(&builder, *expected, {}); 506} 507 508// Regression test for b/31944287. x*y is used (at the same index) by all 509// operands of the concat. We should emit x*y in three incoming basic blocks of 510// the concat because these basic blocks are not control-equivalent. 
511// 512// x*y 513// / | \ 514// add1 add2 add3 515// \ | / 516// concat 517XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { 518 auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); 519 auto x_literal = Literal::CreateR0<float>(2.f); 520 auto y_literal = Literal::CreateR0<float>(3.f); 521 auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); 522 auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); 523 524 ComputationBuilder builder(client_, TestName()); 525 auto x = builder.Parameter(0, f32_scalar, "x"); 526 auto y = builder.Parameter(1, f32_scalar, "y"); 527 auto mul = builder.Mul(x, y); 528 auto add1 = builder.Add(mul, builder.ConstantR1<float>({1.f, 2.f})); 529 auto add2 = builder.Add(mul, builder.ConstantR1<float>({3.f, 4.f})); 530 auto add3 = builder.Add(mul, builder.ConstantR1<float>({5.f, 6.f})); 531 builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0); 532 533 ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.}, 534 {x_data.get(), y_data.get()}, ErrorSpec(1e-4)); 535} 536 537// Test that the HLO optimization to replace a concat of a bradcasted scalar 538// produces the correct result in rank 1. 
539XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { 540 auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); 541 auto x_literal = Literal::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f}); 542 auto y_literal = Literal::CreateR0<float>(1.5f); 543 auto z_literal = Literal::CreateR0<float>(5.5f); 544 auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); 545 auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); 546 auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); 547 548 ComputationBuilder builder(client_, TestName()); 549 auto x = builder.Parameter(0, x_literal->shape(), "x"); 550 auto y = builder.Parameter(1, f32_scalar, "y"); 551 auto z = builder.Parameter(2, f32_scalar, "z"); 552 auto bcast = builder.Broadcast(y, {5}); 553 auto bcast2 = builder.Broadcast(z, {3}); 554 auto concat = builder.ConcatInDim({bcast, x}, /*dimension=*/0); 555 builder.ConcatInDim({concat, bcast2}, /*dimension=*/0); 556 557 ComputeAndCompareR1<float>( 558 &builder, 559 {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f, 3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f}, 560 {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4)); 561} 562 563// Test that the HLO optimization to replace a concat of a bradcasted scalar 564// produces the correct result in rank 3 with both high and low padding in 565// different dimensions. 
566XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { 567 auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); 568 Array3D<float> x3d(3, 5, 7, 3.14f); 569 auto x_literal = Literal::CreateR3FromArray3D<float>(x3d); 570 auto y_literal = Literal::CreateR0<float>(1.5f); 571 auto z_literal = Literal::CreateR0<float>(5.5f); 572 auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); 573 auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); 574 auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); 575 576 ComputationBuilder builder(client_, TestName()); 577 auto x = builder.Parameter(0, x_literal->shape(), "x"); 578 auto y = builder.Parameter(1, f32_scalar, "y"); 579 auto z = builder.Parameter(2, f32_scalar, "y"); 580 auto y_bcast = builder.Broadcast(y, {1, 5, 7}); 581 auto z_bcast = builder.Broadcast(z, {4, 1, 7}); 582 auto concat = builder.ConcatInDim({y_bcast, x}, /*dimension=*/0); 583 builder.ConcatInDim({concat, z_bcast}, /*dimension=*/1); 584 Array3D<float> y_bcast3d(1, 5, 7, 1.5f); 585 Array3D<float> z_bcast3d(4, 1, 7, 5.5f); 586 auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0); 587 auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1); 588 589 ComputeAndCompareR3<float>(&builder, *concat1, 590 {x_data.get(), y_data.get(), z_data.get()}, 591 ErrorSpec(1e-4)); 592} 593 594INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest, 595 ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0}, 596 R2BinarySpec{1, 1, 1, 1, 1}, 597 R2BinarySpec{4, 3, 4, 3, 0}, 598 R2BinarySpec{4, 3, 4, 3, 1}, 599 R2BinarySpec{7, 128, 1, 128, 0}, 600 R2BinarySpec{8, 127, 8, 1, 1})); 601 602} // namespace 603} // namespace xla 604