1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <initializer_list>
17#include <memory>
18
19#include "tensorflow/compiler/xla/array2d.h"
20#include "tensorflow/compiler/xla/client/computation.h"
21#include "tensorflow/compiler/xla/client/computation_builder.h"
22#include "tensorflow/compiler/xla/client/local_client.h"
23#include "tensorflow/compiler/xla/literal_util.h"
24#include "tensorflow/compiler/xla/shape_util.h"
25#include "tensorflow/compiler/xla/statusor.h"
26#include "tensorflow/compiler/xla/test_helpers.h"
27#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28#include "tensorflow/compiler/xla/tests/literal_test_util.h"
29#include "tensorflow/compiler/xla/tests/test_macros.h"
30#include "tensorflow/compiler/xla/xla_data.pb.h"
31#include "tensorflow/core/platform/test.h"
32
33namespace xla {
34namespace {
35
36class TupleTest : public ClientLibraryTestBase {
37 public:
38  ErrorSpec error_spec_{0.0001};
39};
40
41// Tests a tuple-shaped constant.
42XLA_TEST_F(TupleTest, TupleConstant) {
43  ComputationBuilder builder(client_, TestName());
44
45  const float constant_scalar = 7.3f;
46  std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
47  std::initializer_list<std::initializer_list<float>> constant_matrix = {
48      {1.1f, 2.2f, 3.5f},  // row 0
49      {4.8f, 5.0f, 6.7f},  // row 1
50  };
51  auto value =
52      Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
53                          Literal::CreateR1<float>(constant_vector).get(),
54                          Literal::CreateR2<float>(constant_matrix).get()});
55
56  auto result = builder.ConstantLiteral(*value);
57  ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
58}
59
60// Tests a tuple made of scalar constants.
61XLA_TEST_F(TupleTest, TupleScalarConstant) {
62  ComputationBuilder builder(client_, TestName());
63
64  const float constant_scalar1 = 7.3f;
65  const float constant_scalar2 = 1.2f;
66  auto value =
67      Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(),
68                          Literal::CreateR0<float>(constant_scalar2).get()});
69
70  auto result = builder.ConstantLiteral(*value);
71  ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
72}
73
74// Tests the creation of tuple data.
75XLA_TEST_F(TupleTest, TupleCreate) {
76  ComputationBuilder builder(client_, TestName());
77
78  const float constant_scalar = 7.3f;
79  std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
80  std::initializer_list<std::initializer_list<float>> constant_matrix = {
81      {1.1f, 2.2f, 3.5f},  // row 0
82      {4.8f, 5.0f, 6.7f},  // row 1
83  };
84  auto result = builder.Tuple({builder.ConstantR0<float>(constant_scalar),
85                               builder.ConstantR1<float>(constant_vector),
86                               builder.ConstantR2<float>(constant_matrix)});
87
88  auto expected =
89      Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
90                          Literal::CreateR1<float>(constant_vector).get(),
91                          Literal::CreateR2<float>(constant_matrix).get()});
92  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
93}
94
95// Tests the creation of tuple data.
96XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
97  ComputationBuilder builder(client_, TestName());
98
99  auto result = builder.Tuple(
100      {builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})});
101
102  auto expected = Literal::MakeTuple({Literal::CreateR0<float>(7.0).get(),
103                                      Literal::CreateR1<float>({}).get()});
104  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
105}
106
107// Tests the creation of an empty tuple.
108XLA_TEST_F(TupleTest, EmptyTupleCreate) {
109  ComputationBuilder builder(client_, TestName());
110  auto result = builder.Tuple({});
111  auto expected = Literal::MakeTuple({});
112  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
113}
114
115// Trivial test for extracting a tuple element with GetTupleElement.
116XLA_TEST_F(TupleTest, GetTupleElement) {
117  ComputationBuilder builder(client_, TestName());
118  std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
119  std::initializer_list<std::initializer_list<float>> constant_matrix = {
120      {1.f, 2.f, 3.f},  // row 0
121      {4.f, 5.f, 6.f},  // row 1
122  };
123  auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
124                                   builder.ConstantR2<float>(constant_matrix)});
125  auto matrix_element = builder.GetTupleElement(tuple_data, 1);
126  ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
127                             error_spec_);
128}
129
130// Trivial test for extracting a tuple element with GetTupleElement.
131XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
132  ComputationBuilder builder(client_, TestName());
133  auto tuple_data = builder.Tuple(
134      {builder.ConstantR1<float>({}),
135       builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))});
136  auto matrix_element = builder.GetTupleElement(tuple_data, 1);
137  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
138}
139
140XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
141  ComputationBuilder builder(client_, TestName());
142  auto value = builder.ConstantR1<float>({4.5f});
143  builder.GetTupleElement(value, 1);
144  auto result_status = builder.Build();
145  EXPECT_FALSE(result_status.ok());
146  EXPECT_THAT(
147      result_status.status().error_message(),
148      ::testing::HasSubstr("Operand to GetTupleElement() is not a tuple"));
149}
150
151// Extracts both elements from a tuple with GetTupleElement and then adds them
152// together.
153XLA_TEST_F(TupleTest, AddTupleElements) {
154  ComputationBuilder builder(client_, TestName());
155  std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
156  std::initializer_list<std::initializer_list<float>> constant_matrix = {
157      {1.f, 2.f, 3.f},  // row 0
158      {4.f, 5.f, 6.f},  // row 1
159  };
160  auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
161                                   builder.ConstantR2<float>(constant_matrix)});
162  auto vector_element = builder.GetTupleElement(tuple_data, 0);
163  auto matrix_element = builder.GetTupleElement(tuple_data, 1);
164  auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie();
165  auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie();
166  auto result = builder.Add(matrix_element, vector_element,
167                            /*broadcast_dimensions=*/{1});
168
169  Array2D<float> expected({
170      {2.f, 4.f, 6.f},  // row 0
171      {5.f, 7.f, 9.f},  // row 1
172  });
173  ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3}));
174  ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3}));
175  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
176}
177
178// Extracts both elements from a tuple and then puts them into a new tuple in
179// the opposite order.
180XLA_TEST_F(TupleTest, TupleGTEToTuple) {
181  ComputationBuilder builder(client_, TestName());
182  std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
183  std::initializer_list<std::initializer_list<float>> constant_matrix = {
184      {1.f, 2.f, 3.f},  // row 0
185      {4.f, 5.f, 6.f},  // row 1
186  };
187  auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
188                                   builder.ConstantR2<float>(constant_matrix)});
189  auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
190                                  builder.GetTupleElement(tuple_data, 0)});
191  auto expected =
192      Literal::MakeTuple({Literal::CreateR2<float>(constant_matrix).get(),
193                          Literal::CreateR1<float>(constant_vector).get()});
194  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
195}
196
197XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
198  ComputationBuilder b(client_, TestName());
199  ComputationDataHandle v1, v2;
200
201  for (bool direction : {false, true}) {
202    std::unique_ptr<GlobalData> v1_data =
203        CreateR0Parameter<float>(0.0f, /*parameter_number=*/0, /*name=*/"v1",
204                                 /*builder=*/&b, /*data_handle=*/&v1);
205    std::unique_ptr<GlobalData> v2_data =
206        CreateR0Parameter<float>(1.0f, /*parameter_number=*/1, /*name=*/"v2",
207                                 /*builder=*/&b, /*data_handle=*/&v2);
208    auto v1_gt = b.Gt(v1, v2);             // false
209    auto v2_gt = b.Gt(v2, v1);             // true
210    auto v1_v2 = b.Tuple({v1_gt, v2_gt});  // {false, true}
211    auto v2_v1 = b.Tuple({v2_gt, v1_gt});  // {true, false}
212    auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
213    auto expected =
214        Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(),
215                            Literal::CreateR0<bool>(!direction).get()});
216
217    ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
218                           error_spec_);
219  }
220}
221
222// Builds two new tuples from an existing tuple (by means of GetTupleElement),
223// then adds up the components of the new tuples.
224XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
225  //
226  // v------           --(GTE 0)--             --(GTE 0)----------
227  //        \         /           \           /                   \
228  //         (tuple)--             (tuple01)--                     \
229  //        /   |     \           /           \                     \
230  // m------    |      --(GTE 1)--             --(GTE 1)------------ \
231  //            |                                                   \ \
232  //            |                                                    (add)
233  //            |                                                   / /
234  //            |--------(GTE 1)--             --(GTE 0)------------ /
235  //             \                \           /                     /
236  //              \                (tuple10)--                     /
237  //               \              /           \                   /
238  //                -----(GTE 0)--             --(GTE 1)----------
239  ComputationBuilder builder(client_, TestName());
240  std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
241  std::initializer_list<std::initializer_list<float>> constant_matrix = {
242      {1.f, 2.f, 3.f},  // row 0
243      {4.f, 5.f, 6.f},  // row 1
244  };
245  auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
246                                   builder.ConstantR2<float>(constant_matrix)});
247  auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0),
248                                    builder.GetTupleElement(tuple_data, 1)});
249  auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
250                                    builder.GetTupleElement(tuple_data, 0)});
251  auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0);
252  auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1);
253  auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1);
254  auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0);
255
256  auto addvectors = builder.Add(vector_from_01, vector_from_10);
257  auto addmatrices = builder.Add(matrix_from_01, matrix_from_10);
258
259  auto result = builder.Add(addmatrices, addvectors,
260                            /*broadcast_dimensions=*/{1});
261
262  Array2D<float> expected({
263      {4.f, 8.f, 12.f},    // row 0
264      {10.f, 14.f, 18.f},  // row 1
265  });
266  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
267}
268
269XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
270  // Tests a selection between tuples with "false" path taken.
271  ComputationBuilder builder(client_, TestName());
272
273  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
274  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
275  auto tuple12 = builder.Tuple(
276      {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
277  auto tuple21 = builder.Tuple(
278      {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
279
280  auto select =
281      builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
282  auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
283                                      Literal::CreateR1<float>(vec1).get()});
284  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
285}
286
287XLA_TEST_F(TupleTest, TuplesInAMap) {
288  Computation tuple_computation;
289  {
290    // tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples.
291    //
292    // Need to put a select in there to prevent HLO-level optimizations from
293    // optimizing out the tuples.
294    ComputationBuilder b(client_, "sort_square");
295    auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
296    auto x2 = b.Mul(x, x);
297    auto x_smaller_tuple = b.Tuple({x, x2});
298    auto x2_smaller_tuple = b.Tuple({x2, x});
299    auto sorted = b.Select(b.Lt(x, x2), x_smaller_tuple, x2_smaller_tuple);
300    auto smaller = b.GetTupleElement(sorted, 0);
301    auto greater = b.GetTupleElement(sorted, 1);
302    b.Add(greater, b.Mul(b.ConstantR0<float>(100.0f), smaller));
303    auto computation_status = b.Build();
304    ASSERT_IS_OK(computation_status.status());
305    tuple_computation = computation_status.ConsumeValueOrDie();
306  }
307
308  ComputationBuilder b(client_, TestName());
309  auto input = b.ConstantR1<float>({-1.0f, 1.0f, 2.1f});
310  b.Map({input}, tuple_computation, {0});
311  ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_);
312}
313
314XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
315  // Tests a selection between tuples with "true" path taken.
316  ComputationBuilder builder(client_, TestName());
317
318  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
319  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
320  auto tuple12 = builder.Tuple(
321      {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
322  auto tuple21 = builder.Tuple(
323      {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
324
325  auto select =
326      builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
327  auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec1).get(),
328                                      Literal::CreateR1<float>(vec2).get()});
329  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
330}
331
332XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
333  // Tests a selection between tuples but the final result is an element of the
334  // tuple, not the whole tuple.
335  ComputationBuilder builder(client_, TestName());
336
337  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
338  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
339  auto tuple12 = builder.Tuple(
340      {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
341  auto tuple21 = builder.Tuple(
342      {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
343
344  auto select =
345      builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
346  auto element = builder.GetTupleElement(select, 0);
347
348  ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
349}
350
351// Cascaded selects between tuple types.
352XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
353  //
354  //                       vec1     vec2   vec2     vec1
355  //                        |        |      |        |
356  //                        |        |      |        |
357  //                        (tuple 12)      (tuple 21)
358  //                               \            /
359  //                                \          /
360  //                                 \        /
361  //  true  --            --(GTE 0)--(select 1)
362  //          \          /             |
363  //       (pred tuple)--              |          --(GTE 0)--
364  //          /          \             V         /           \
365  //  false --            --(GTE 1)--(select 2)--             --(add)
366  //                                 /           \           /
367  //                                /             --(GTE 1)--
368  //                               /
369  //                          (tuple 21)
370  ComputationBuilder builder(client_, TestName());
371
372  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
373  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
374
375  auto pred_tuple = builder.Tuple(
376      {builder.ConstantR0<bool>(true), builder.ConstantR0<bool>(false)});
377  auto tuple12 = builder.Tuple(
378      {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
379  auto tuple21 = builder.Tuple(
380      {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
381
382  auto select1 =
383      builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21);
384  auto select2 =
385      builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1);
386  auto result = builder.Add(builder.GetTupleElement(select2, 0),
387                            builder.GetTupleElement(select2, 1));
388
389  ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
390}
391
392XLA_TEST_F(TupleTest,
393           DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) {
394  // Similar to SelectBetweenTuples, but the constants are shared between the
395  // input tuples.
396  ComputationBuilder builder(client_, TestName());
397
398  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
399  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
400  auto c1 = builder.ConstantR1<float>(vec1);
401  auto c2 = builder.ConstantR1<float>(vec2);
402  auto tuple12 = builder.Tuple({c1, c2});
403  auto tuple21 = builder.Tuple({c2, c1});
404
405  auto select =
406      builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
407  auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
408                                      Literal::CreateR1<float>(vec1).get()});
409  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
410}
411
412XLA_TEST_F(TupleTest, NestedTuples) {
413  ComputationBuilder builder(client_, TestName());
414  auto inner_tuple = builder.Tuple(
415      {builder.ConstantR1<float>({1.0, 2.0}), builder.ConstantR0<float>(42.0)});
416  auto outer_tuple =
417      builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
418
419  auto expected_v1 = Literal::CreateR1<float>({1.0, 2.0});
420  auto expected_s = Literal::CreateR0<float>(42.0);
421  auto expected_inner_tuple =
422      Literal::MakeTuple({expected_v1.get(), expected_s.get()});
423  auto expected_v2 = Literal::CreateR1<float>({22.0, 44.0});
424  auto expected =
425      Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
426
427  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
428}
429
430XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
431  ComputationBuilder builder(client_, TestName());
432
433  Shape data_shape = ShapeUtil::MakeShape(F32, {3});
434  Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
435  Shape outer_tuple_shape =
436      ShapeUtil::MakeTupleShape({inner_tuple_shape, data_shape});
437
438  auto input = builder.Parameter(0, outer_tuple_shape, "input");
439  auto gte0 = builder.GetTupleElement(input, 0);
440  auto gte1 = builder.GetTupleElement(gte0, 1);
441  builder.Add(gte1, builder.ConstantR1<float>({10.0, 11.0, 12.0}));
442
443  std::unique_ptr<GlobalData> data =
444      client_
445          ->TransferToServer(*Literal::MakeTuple({
446              Literal::MakeTuple(
447                  {
448                      Literal::CreateR1<float>({1.0, 2.0, 3.0}).get(),
449                      Literal::CreateR1<float>({4.0, 5.0, 6.0}).get(),
450                  })
451                  .get(),
452              Literal::CreateR1<float>({7.0, 8.0, 9.0}).get(),
453          }))
454          .ConsumeValueOrDie();
455
456  std::vector<GlobalData*> arguments = {data.get()};
457  const std::vector<float> expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0};
458  ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
459}
460
461XLA_TEST_F(TupleTest, ComplexTuples) {
462  ComputationBuilder builder(client_, TestName());
463  {
464    Shape c64r0 = ShapeUtil::MakeShape(C64, {});
465    Shape c64r1 = ShapeUtil::MakeShape(C64, {2});
466    Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2});
467    Shape arg0_shape = ShapeUtil::MakeTupleShape(
468        {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})});
469    auto input0 = builder.Parameter(0, arg0_shape, "input0");
470    auto t0 = builder.GetTupleElement(input0, 0);
471    auto t1 = builder.GetTupleElement(input0, 1);
472    auto t10 = builder.GetTupleElement(t1, 0);
473    auto t11 = builder.GetTupleElement(t1, 1);
474    auto sum = builder.Add(builder.Add(t10, t11, {1}), t0);
475    auto input1 = builder.Parameter(1, c64r1, "input1");
476    auto prod = builder.Mul(input1, sum, {1});
477    builder.Tuple({builder.Tuple({prod, sum}),
478                   builder.ConstantR0<complex64>({123, 456})});
479  }
480
481  std::unique_ptr<GlobalData> arg0 =
482      client_
483          ->TransferToServer(*Literal::MakeTuple(
484              {Literal::CreateR0<complex64>({1, 2}).get(),
485               Literal::MakeTuple(
486                   {Literal::CreateR1<complex64>({{10, 20}, {30, 40}}).get(),
487                    Literal::CreateR2<complex64>(
488                        {{{100, 200}, {300, 400}},
489                         {{1000, 2000}, {3000, 4000}},
490                         {{10000, 20000}, {30000, 40000}}})
491                        .get()})
492                   .get()}))
493          .ConsumeValueOrDie();
494  std::unique_ptr<GlobalData> arg1 =
495      client_
496          ->TransferToServer(*Literal::CreateR1<complex64>({{1, 2}, {1, -2}}))
497          .ConsumeValueOrDie();
498  auto sum = Literal::CreateR2<complex64>({{{111, 222}, {331, 442}},
499                                           {{1011, 2022}, {3031, 4042}},
500                                           {{10011, 20022}, {30031, 40042}}});
501  auto prod = Literal::CreateFromShape(sum->shape());
502  ASSERT_TRUE(prod->Populate<complex64>(
503                      [&sum](tensorflow::gtl::ArraySlice<int64> indexes) {
504                        return sum->Get<complex64>(indexes) *
505                               (indexes[indexes.size() - 1] == 0
506                                    ? complex64(1, 2)
507                                    : complex64(1, -2));
508                      })
509                  .ok());
510  auto expected =
511      Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(),
512                          Literal::CreateR0<complex64>({123, 456}).get()});
513  ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()},
514                         error_spec_);
515}
516
517}  // namespace
518}  // namespace xla
519