1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <initializer_list> 17#include <memory> 18 19#include "tensorflow/compiler/xla/array2d.h" 20#include "tensorflow/compiler/xla/client/computation.h" 21#include "tensorflow/compiler/xla/client/computation_builder.h" 22#include "tensorflow/compiler/xla/client/local_client.h" 23#include "tensorflow/compiler/xla/literal_util.h" 24#include "tensorflow/compiler/xla/shape_util.h" 25#include "tensorflow/compiler/xla/statusor.h" 26#include "tensorflow/compiler/xla/test_helpers.h" 27#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 28#include "tensorflow/compiler/xla/tests/literal_test_util.h" 29#include "tensorflow/compiler/xla/tests/test_macros.h" 30#include "tensorflow/compiler/xla/xla_data.pb.h" 31#include "tensorflow/core/platform/test.h" 32 33namespace xla { 34namespace { 35 36class TupleTest : public ClientLibraryTestBase { 37 public: 38 ErrorSpec error_spec_{0.0001}; 39}; 40 41// Tests a tuple-shaped constant. 42XLA_TEST_F(TupleTest, TupleConstant) { 43 ComputationBuilder builder(client_, TestName()); 44 45 const float constant_scalar = 7.3f; 46 std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f}; 47 std::initializer_list<std::initializer_list<float>> constant_matrix = { 48 {1.1f, 2.2f, 3.5f}, // row 0 49 {4.8f, 5.0f, 6.7f}, // row 1 50 }; 51 auto value = 52 Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(), 53 Literal::CreateR1<float>(constant_vector).get(), 54 Literal::CreateR2<float>(constant_matrix).get()}); 55 56 auto result = builder.ConstantLiteral(*value); 57 ComputeAndCompareTuple(&builder, *value, {}, error_spec_); 58} 59 60// Tests a tuple made of scalar constants. 61XLA_TEST_F(TupleTest, TupleScalarConstant) { 62 ComputationBuilder builder(client_, TestName()); 63 64 const float constant_scalar1 = 7.3f; 65 const float constant_scalar2 = 1.2f; 66 auto value = 67 Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(), 68 Literal::CreateR0<float>(constant_scalar2).get()}); 69 70 auto result = builder.ConstantLiteral(*value); 71 ComputeAndCompareTuple(&builder, *value, {}, error_spec_); 72} 73 74// Tests the creation of tuple data. 75XLA_TEST_F(TupleTest, TupleCreate) { 76 ComputationBuilder builder(client_, TestName()); 77 78 const float constant_scalar = 7.3f; 79 std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f}; 80 std::initializer_list<std::initializer_list<float>> constant_matrix = { 81 {1.1f, 2.2f, 3.5f}, // row 0 82 {4.8f, 5.0f, 6.7f}, // row 1 83 }; 84 auto result = builder.Tuple({builder.ConstantR0<float>(constant_scalar), 85 builder.ConstantR1<float>(constant_vector), 86 builder.ConstantR2<float>(constant_matrix)}); 87 88 auto expected = 89 Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(), 90 Literal::CreateR1<float>(constant_vector).get(), 91 Literal::CreateR2<float>(constant_matrix).get()}); 92 ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); 93} 94 95// Tests the creation of tuple data. 96XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { 97 ComputationBuilder builder(client_, TestName()); 98 99 auto result = builder.Tuple( 100 {builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})}); 101 102 auto expected = Literal::MakeTuple({Literal::CreateR0<float>(7.0).get(), 103 Literal::CreateR1<float>({}).get()}); 104 ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); 105} 106 107// Tests the creation of an empty tuple. 108XLA_TEST_F(TupleTest, EmptyTupleCreate) { 109 ComputationBuilder builder(client_, TestName()); 110 auto result = builder.Tuple({}); 111 auto expected = Literal::MakeTuple({}); 112 ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); 113} 114 115// Trivial test for extracting a tuple element with GetTupleElement. 116XLA_TEST_F(TupleTest, GetTupleElement) { 117 ComputationBuilder builder(client_, TestName()); 118 std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f}; 119 std::initializer_list<std::initializer_list<float>> constant_matrix = { 120 {1.f, 2.f, 3.f}, // row 0 121 {4.f, 5.f, 6.f}, // row 1 122 }; 123 auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector), 124 builder.ConstantR2<float>(constant_matrix)}); 125 auto matrix_element = builder.GetTupleElement(tuple_data, 1); 126 ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {}, 127 error_spec_); 128} 129 130// Trivial test for extracting a tuple element with GetTupleElement. 131XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) { 132 ComputationBuilder builder(client_, TestName()); 133 auto tuple_data = builder.Tuple( 134 {builder.ConstantR1<float>({}), 135 builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))}); 136 auto matrix_element = builder.GetTupleElement(tuple_data, 1); 137 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_); 138} 139 140XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { 141 ComputationBuilder builder(client_, TestName()); 142 auto value = builder.ConstantR1<float>({4.5f}); 143 builder.GetTupleElement(value, 1); 144 auto result_status = builder.Build(); 145 EXPECT_FALSE(result_status.ok()); 146 EXPECT_THAT( 147 result_status.status().error_message(), 148 ::testing::HasSubstr("Operand to GetTupleElement() is not a tuple")); 149} 150 151// Extracts both elements from a tuple with GetTupleElement and then adds them 152// together. 153XLA_TEST_F(TupleTest, AddTupleElements) { 154 ComputationBuilder builder(client_, TestName()); 155 std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f}; 156 std::initializer_list<std::initializer_list<float>> constant_matrix = { 157 {1.f, 2.f, 3.f}, // row 0 158 {4.f, 5.f, 6.f}, // row 1 159 }; 160 auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector), 161 builder.ConstantR2<float>(constant_matrix)}); 162 auto vector_element = builder.GetTupleElement(tuple_data, 0); 163 auto matrix_element = builder.GetTupleElement(tuple_data, 1); 164 auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie(); 165 auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie(); 166 auto result = builder.Add(matrix_element, vector_element, 167 /*broadcast_dimensions=*/{1}); 168 169 Array2D<float> expected({ 170 {2.f, 4.f, 6.f}, // row 0 171 {5.f, 7.f, 9.f}, // row 1 172 }); 173 ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3})); 174 ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3})); 175 ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); 176} 177 178// Extracts both elements from a tuple and then puts them into a new tuple in 179// the opposite order. 180XLA_TEST_F(TupleTest, TupleGTEToTuple) { 181 ComputationBuilder builder(client_, TestName()); 182 std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f}; 183 std::initializer_list<std::initializer_list<float>> constant_matrix = { 184 {1.f, 2.f, 3.f}, // row 0 185 {4.f, 5.f, 6.f}, // row 1 186 }; 187 auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector), 188 builder.ConstantR2<float>(constant_matrix)}); 189 auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1), 190 builder.GetTupleElement(tuple_data, 0)}); 191 auto expected = 192 Literal::MakeTuple({Literal::CreateR2<float>(constant_matrix).get(), 193 Literal::CreateR1<float>(constant_vector).get()}); 194 ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); 195} 196 197XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { 198 ComputationBuilder b(client_, TestName()); 199 ComputationDataHandle v1, v2; 200 201 for (bool direction : {false, true}) { 202 std::unique_ptr<GlobalData> v1_data = 203 CreateR0Parameter<float>(0.0f, /*parameter_number=*/0, /*name=*/"v1", 204 /*builder=*/&b, /*data_handle=*/&v1); 205 std::unique_ptr<GlobalData> v2_data = 206 CreateR0Parameter<float>(1.0f, /*parameter_number=*/1, /*name=*/"v2", 207 /*builder=*/&b, /*data_handle=*/&v2); 208 auto v1_gt = b.Gt(v1, v2); // false 209 auto v2_gt = b.Gt(v2, v1); // true 210 auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} 211 auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} 212 auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); 213 auto expected = 214 Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(), 215 Literal::CreateR0<bool>(!direction).get()}); 216 217 ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, 218 error_spec_); 219 } 220} 221 222// Builds two new tuples from an existing tuple (by means of GetTupleElement), 223// then adds up the components of the new tuples. 224XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { 225 // 226 // v------ --(GTE 0)-- --(GTE 0)---------- 227 // \ / \ / \ 228 // (tuple)-- (tuple01)-- \ 229 // / | \ / \ \ 230 // m------ | --(GTE 1)-- --(GTE 1)------------ \ 231 // | \ \ 232 // | (add) 233 // | / / 234 // |--------(GTE 1)-- --(GTE 0)------------ / 235 // \ \ / / 236 // \ (tuple10)-- / 237 // \ / \ / 238 // -----(GTE 0)-- --(GTE 1)---------- 239 ComputationBuilder builder(client_, TestName()); 240 std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f}; 241 std::initializer_list<std::initializer_list<float>> constant_matrix = { 242 {1.f, 2.f, 3.f}, // row 0 243 {4.f, 5.f, 6.f}, // row 1 244 }; 245 auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector), 246 builder.ConstantR2<float>(constant_matrix)}); 247 auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0), 248 builder.GetTupleElement(tuple_data, 1)}); 249 auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1), 250 builder.GetTupleElement(tuple_data, 0)}); 251 auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0); 252 auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1); 253 auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1); 254 auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0); 255 256 auto addvectors = builder.Add(vector_from_01, vector_from_10); 257 auto addmatrices = builder.Add(matrix_from_01, matrix_from_10); 258 259 auto result = builder.Add(addmatrices, addvectors, 260 /*broadcast_dimensions=*/{1}); 261 262 Array2D<float> expected({ 263 {4.f, 8.f, 12.f}, // row 0 264 {10.f, 14.f, 18.f}, // row 1 265 }); 266 ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); 267} 268 269XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { 270 // Tests a selection between tuples with "false" path taken. 271 ComputationBuilder builder(client_, TestName()); 272 273 std::initializer_list<float> vec1 = {1.f, 2.f, 3.f}; 274 std::initializer_list<float> vec2 = {2.f, 4.f, 6.f}; 275 auto tuple12 = builder.Tuple( 276 {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)}); 277 auto tuple21 = builder.Tuple( 278 {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)}); 279 280 auto select = 281 builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21); 282 auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(), 283 Literal::CreateR1<float>(vec1).get()}); 284 ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); 285} 286 287XLA_TEST_F(TupleTest, TuplesInAMap) { 288 Computation tuple_computation; 289 { 290 // tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples. 291 // 292 // Need to put a select in there to prevent HLO-level optimizations from 293 // optimizing out the tuples. 294 ComputationBuilder b(client_, "sort_square"); 295 auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 296 auto x2 = b.Mul(x, x); 297 auto x_smaller_tuple = b.Tuple({x, x2}); 298 auto x2_smaller_tuple = b.Tuple({x2, x}); 299 auto sorted = b.Select(b.Lt(x, x2), x_smaller_tuple, x2_smaller_tuple); 300 auto smaller = b.GetTupleElement(sorted, 0); 301 auto greater = b.GetTupleElement(sorted, 1); 302 b.Add(greater, b.Mul(b.ConstantR0<float>(100.0f), smaller)); 303 auto computation_status = b.Build(); 304 ASSERT_IS_OK(computation_status.status()); 305 tuple_computation = computation_status.ConsumeValueOrDie(); 306 } 307 308 ComputationBuilder b(client_, TestName()); 309 auto input = b.ConstantR1<float>({-1.0f, 1.0f, 2.1f}); 310 b.Map({input}, tuple_computation, {0}); 311 ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); 312} 313 314XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { 315 // Tests a selection between tuples with "true" path taken. 316 ComputationBuilder builder(client_, TestName()); 317 318 std::initializer_list<float> vec1 = {1.f, 2.f, 3.f}; 319 std::initializer_list<float> vec2 = {2.f, 4.f, 6.f}; 320 auto tuple12 = builder.Tuple( 321 {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)}); 322 auto tuple21 = builder.Tuple( 323 {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)}); 324 325 auto select = 326 builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21); 327 auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec1).get(), 328 Literal::CreateR1<float>(vec2).get()}); 329 ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); 330} 331 332XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { 333 // Tests a selection between tuples but the final result is an element of the 334 // tuple, not the whole tuple. 335 ComputationBuilder builder(client_, TestName()); 336 337 std::initializer_list<float> vec1 = {1.f, 2.f, 3.f}; 338 std::initializer_list<float> vec2 = {2.f, 4.f, 6.f}; 339 auto tuple12 = builder.Tuple( 340 {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)}); 341 auto tuple21 = builder.Tuple( 342 {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)}); 343 344 auto select = 345 builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21); 346 auto element = builder.GetTupleElement(select, 0); 347 348 ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_); 349} 350 351// Cascaded selects between tuple types. 352XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { 353 // 354 // vec1 vec2 vec2 vec1 355 // | | | | 356 // | | | | 357 // (tuple 12) (tuple 21) 358 // \ / 359 // \ / 360 // \ / 361 // true -- --(GTE 0)--(select 1) 362 // \ / | 363 // (pred tuple)-- | --(GTE 0)-- 364 // / \ V / \ 365 // false -- --(GTE 1)--(select 2)-- --(add) 366 // / \ / 367 // / --(GTE 1)-- 368 // / 369 // (tuple 21) 370 ComputationBuilder builder(client_, TestName()); 371 372 std::initializer_list<float> vec1 = {1.f, 2.f, 3.f}; 373 std::initializer_list<float> vec2 = {2.f, 4.f, 6.f}; 374 375 auto pred_tuple = builder.Tuple( 376 {builder.ConstantR0<bool>(true), builder.ConstantR0<bool>(false)}); 377 auto tuple12 = builder.Tuple( 378 {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)}); 379 auto tuple21 = builder.Tuple( 380 {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)}); 381 382 auto select1 = 383 builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21); 384 auto select2 = 385 builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1); 386 auto result = builder.Add(builder.GetTupleElement(select2, 0), 387 builder.GetTupleElement(select2, 1)); 388 389 ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_); 390} 391 392XLA_TEST_F(TupleTest, 393 DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) { 394 // Similar to SelectBetweenTuples, but the constants are shared between the 395 // input tuples. 396 ComputationBuilder builder(client_, TestName()); 397 398 std::initializer_list<float> vec1 = {1.f, 2.f, 3.f}; 399 std::initializer_list<float> vec2 = {2.f, 4.f, 6.f}; 400 auto c1 = builder.ConstantR1<float>(vec1); 401 auto c2 = builder.ConstantR1<float>(vec2); 402 auto tuple12 = builder.Tuple({c1, c2}); 403 auto tuple21 = builder.Tuple({c2, c1}); 404 405 auto select = 406 builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21); 407 auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(), 408 Literal::CreateR1<float>(vec1).get()}); 409 ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); 410} 411 412XLA_TEST_F(TupleTest, NestedTuples) { 413 ComputationBuilder builder(client_, TestName()); 414 auto inner_tuple = builder.Tuple( 415 {builder.ConstantR1<float>({1.0, 2.0}), builder.ConstantR0<float>(42.0)}); 416 auto outer_tuple = 417 builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})}); 418 419 auto expected_v1 = Literal::CreateR1<float>({1.0, 2.0}); 420 auto expected_s = Literal::CreateR0<float>(42.0); 421 auto expected_inner_tuple = 422 Literal::MakeTuple({expected_v1.get(), expected_s.get()}); 423 auto expected_v2 = Literal::CreateR1<float>({22.0, 44.0}); 424 auto expected = 425 Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); 426 427 ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); 428} 429 430XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { 431 ComputationBuilder builder(client_, TestName()); 432 433 Shape data_shape = ShapeUtil::MakeShape(F32, {3}); 434 Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape}); 435 Shape outer_tuple_shape = 436 ShapeUtil::MakeTupleShape({inner_tuple_shape, data_shape}); 437 438 auto input = builder.Parameter(0, outer_tuple_shape, "input"); 439 auto gte0 = builder.GetTupleElement(input, 0); 440 auto gte1 = builder.GetTupleElement(gte0, 1); 441 builder.Add(gte1, builder.ConstantR1<float>({10.0, 11.0, 12.0})); 442 443 std::unique_ptr<GlobalData> data = 444 client_ 445 ->TransferToServer(*Literal::MakeTuple({ 446 Literal::MakeTuple( 447 { 448 Literal::CreateR1<float>({1.0, 2.0, 3.0}).get(), 449 Literal::CreateR1<float>({4.0, 5.0, 6.0}).get(), 450 }) 451 .get(), 452 Literal::CreateR1<float>({7.0, 8.0, 9.0}).get(), 453 })) 454 .ConsumeValueOrDie(); 455 456 std::vector<GlobalData*> arguments = {data.get()}; 457 const std::vector<float> expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0}; 458 ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5)); 459} 460 461XLA_TEST_F(TupleTest, ComplexTuples) { 462 ComputationBuilder builder(client_, TestName()); 463 { 464 Shape c64r0 = ShapeUtil::MakeShape(C64, {}); 465 Shape c64r1 = ShapeUtil::MakeShape(C64, {2}); 466 Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2}); 467 Shape arg0_shape = ShapeUtil::MakeTupleShape( 468 {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})}); 469 auto input0 = builder.Parameter(0, arg0_shape, "input0"); 470 auto t0 = builder.GetTupleElement(input0, 0); 471 auto t1 = builder.GetTupleElement(input0, 1); 472 auto t10 = builder.GetTupleElement(t1, 0); 473 auto t11 = builder.GetTupleElement(t1, 1); 474 auto sum = builder.Add(builder.Add(t10, t11, {1}), t0); 475 auto input1 = builder.Parameter(1, c64r1, "input1"); 476 auto prod = builder.Mul(input1, sum, {1}); 477 builder.Tuple({builder.Tuple({prod, sum}), 478 builder.ConstantR0<complex64>({123, 456})}); 479 } 480 481 std::unique_ptr<GlobalData> arg0 = 482 client_ 483 ->TransferToServer(*Literal::MakeTuple( 484 {Literal::CreateR0<complex64>({1, 2}).get(), 485 Literal::MakeTuple( 486 {Literal::CreateR1<complex64>({{10, 20}, {30, 40}}).get(), 487 Literal::CreateR2<complex64>( 488 {{{100, 200}, {300, 400}}, 489 {{1000, 2000}, {3000, 4000}}, 490 {{10000, 20000}, {30000, 40000}}}) 491 .get()}) 492 .get()})) 493 .ConsumeValueOrDie(); 494 std::unique_ptr<GlobalData> arg1 = 495 client_ 496 ->TransferToServer(*Literal::CreateR1<complex64>({{1, 2}, {1, -2}})) 497 .ConsumeValueOrDie(); 498 auto sum = Literal::CreateR2<complex64>({{{111, 222}, {331, 442}}, 499 {{1011, 2022}, {3031, 4042}}, 500 {{10011, 20022}, {30031, 40042}}}); 501 auto prod = Literal::CreateFromShape(sum->shape()); 502 ASSERT_TRUE(prod->Populate<complex64>( 503 [&sum](tensorflow::gtl::ArraySlice<int64> indexes) { 504 return sum->Get<complex64>(indexes) * 505 (indexes[indexes.size() - 1] == 0 506 ? complex64(1, 2) 507 : complex64(1, -2)); 508 }) 509 .ok()); 510 auto expected = 511 Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(), 512 Literal::CreateR0<complex64>({123, 456}).get()}); 513 ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, 514 error_spec_); 515} 516 517} // namespace 518} // namespace xla 519