/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace {

// TODO(b/34468543): use GUnit typed tests when we can do all tests on all
// backends.
class DotOperationTest : public ClientLibraryTestBase {
 public:
  ErrorSpec error_spec_{0.0001, 1e-5};

 protected:
  template <typename Element>
  void TestOneElementVectorDot();
  template <typename Element>
  void TestVectorDot();
  template <typename Element>
  void TestSquareMatrixDot(bool lhs_row_major = false,
                           bool rhs_row_major = false);
  template <typename Element>
  void TestNonsquareMatrixDot(bool lhs_row_major = false,
                              bool rhs_row_major = false);
};

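// Dotting two empty vectors reduces over zero elements, yielding 0.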
XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR1<float>({});
  auto rhs = builder.ConstantR1<float>({});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_);
}

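// A 1x2 matrix dotted with a 2-vector gives a 1-vector: 3*3 + 4*4 = 25.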
XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR2<float>({{3.0, 4.0}});
  auto rhs = builder.ConstantR1<float>({3.0, 4.0});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {25.0}, {}, error_spec_);
}

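// Dot of two single-element vectors is a scalar: 2 * 3 = 6.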
template <typename Element>
void DotOperationTest::TestOneElementVectorDot() {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR1<Element>({2.0});
  auto rhs = builder.ConstantR1<Element>({3.0});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR0<Element>(&builder, 6.0, {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) {
  TestOneElementVectorDot<float>();
}

XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) {
  TestOneElementVectorDot<double>();
}

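// Dot of two 3-vectors: 1*11 + 2.5*(-1) + 42*0.5 = 29.5.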
template <typename Element>
void DotOperationTest::TestVectorDot() {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR1<Element>({1.0, 2.5, 42.0});
  auto rhs = builder.ConstantR1<Element>({11.0, -1.0, 0.5});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR0<Element>(&builder, 29.5, {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot<float>(); }

XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot<double>(); }

namespace {

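// Converts a row-major flag into a minor-to-major dimension order for a
// rank-2 array: {1, 0} when row-major (dimension 1 varies fastest), {0, 1}
// when column-major.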
std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
  return {row_major ? 1 : 0, row_major ? 0 : 1};
}

}  // namespace

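// Dots with a zero-sized dimension: the result is empty unless only the
// contracted dimension is empty (2x0 times 0x2), which yields a zero-filled
// 2x2.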
XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
  auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
  auto rhs = builder.ConstantR2<float>({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}});
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 3), {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) {
  ComputationBuilder builder(client_, TestName());
  auto lhs =
      builder.ConstantR2<float>({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}});
  auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {}, error_spec_);
}

XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) {
  ComputationBuilder builder(client_, TestName());
  auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
  auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
  auto result = builder.Dot(lhs, rhs);

  ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 2, 0.0f), {},
                             error_spec_);
}

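// The lhs passes through Exp before the dot, exercising fusion of an
// elementwise op into the dot; expected[r][0] = sum_i exp(lhs[r][i]) *
// rhs[i][0].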
XLA_TEST_F(DotOperationTest, FusedDot) {
  ComputationBuilder builder(client_, TestName());
  auto param0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 4}), "arg0");
  auto param1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4, 1}), "arg1");
  auto exp0 = builder.Exp(param0);
  auto result = builder.Dot(exp0, param1);

  auto lhs_handle = client_
                        ->TransferToServer(*Literal::CreateR2<float>(
                            {{1.0, 2.0, 3.0, 4.0}, {-1.0, -2.0, -3.0, -4.0}}))
                        .ConsumeValueOrDie();
  auto rhs_handle = client_
                        ->TransferToServer(*Literal::CreateR2<float>(
                            {{1.0}, {2.0}, {3.0}, {4.0}}))
                        .ConsumeValueOrDie();

  ComputeAndCompareR2<float>(
      &builder, Array2D<float>({{296.14560492846033}, {0.8611737683031964}}),
      {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}

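// Computes a 2x2 dot of parameters transferred with the requested layouts;
// {{1, 2}, {3, -4}} times {{1, 6}, {7, -4}} is {{15, -2}, {-25, 34}}.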
template <typename Element>
void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
                                           bool rhs_row_major) {
  auto lhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
              {{1.0, 2.0}, {3.0, -4.0}},
              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
          .ConsumeValueOrDie();
  auto rhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
              {{1.0, 6.0}, {7.0, -4.0}},
              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
          .ConsumeValueOrDie();

  ComputationBuilder builder(client_, TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<Element>();
  auto result = builder.Dot(
      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));

  Array2D<Element> expected({{15.0, -2.0}, {-25.0, 34.0}});
  ComputeAndCompareR2<Element>(
      &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}

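// Parameters for an m-by-k times k-by-n dot test: operand layouts plus an
// optional addend applied elementwise to the dot result.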
struct DotTestParam {
  int m;
  int k;
  int n;
  bool dot_lhs_row_major;
  bool dot_rhs_row_major;
  bool has_addend;
  bool addend_row_major;
};

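// Builds a readable test-case name such as "12x117x7_MajorToMinorTF" from the
// shape and layout flags; the addend's layout flag is appended only when an
// addend is present.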
string PrintDotTestParam(
    const ::testing::TestParamInfo<DotTestParam>& test_param) {
  const DotTestParam& param = test_param.param;
  string name = tensorflow::strings::StrCat(
      param.m, "x", param.k, "x", param.n, "_MajorToMinor",
      param.dot_lhs_row_major ? "T" : "F",
      param.dot_rhs_row_major ? "T" : "F");
  if (param.has_addend) {
    tensorflow::strings::StrAppend(&name,
                                   param.addend_row_major ? "T" : "F");
  }
  return name;
}

class ParametricDotTest : public DotOperationTest,
                          public ::testing::WithParamInterface<DotTestParam> {};

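// Transfers linspace-filled operands (and the addend, if any) with the
// requested layouts, runs the dot (plus add), and checks the result against a
// reference matmul computed on the host.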
XLA_TEST_P(ParametricDotTest, TestF32) {
  DotTestParam param = GetParam();

  std::unique_ptr<Array2D<float>> dot_lhs_data =
      MakeLinspaceArray2D(0.0, 1.0, param.m, param.k);
  std::unique_ptr<Literal> dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout(
      *dot_lhs_data, LayoutUtil::MakeLayout(
                         MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
  std::unique_ptr<GlobalData> dot_lhs_handle =
      client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();

  std::unique_ptr<Array2D<float>> dot_rhs_data =
      MakeLinspaceArray2D(0.0, 1.0, param.k, param.n);
  std::unique_ptr<Literal> dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout(
      *dot_rhs_data, LayoutUtil::MakeLayout(
                         MinorToMajorForIsRowMajor(param.dot_rhs_row_major)));
  std::unique_ptr<GlobalData> dot_rhs_handle =
      client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();

  std::unique_ptr<Array2D<float>> addend_data;
  std::unique_ptr<Literal> addend_lit;
  std::unique_ptr<GlobalData> addend_handle;

  if (param.has_addend) {
    addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n);
    addend_lit = Literal::CreateR2FromArray2DWithLayout(
        *addend_data, LayoutUtil::MakeLayout(
                          MinorToMajorForIsRowMajor(param.addend_row_major)));
    addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
  }

  ComputationBuilder builder(client_, TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<float>();
  auto result = builder.Dot(
      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}),
                        "dot_lhs"),
      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}),
                        "dot_rhs"));

  if (param.has_addend) {
    result = builder.Add(
        result,
        builder.Parameter(
            2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend"));
  }

  std::unique_ptr<Array2D<float>> expected;
  if (param.has_addend) {
    expected = ReferenceUtil::ApplyElementwise2D(
        std::plus<float>(),
        *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data),
        *addend_data);
  } else {
    expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data);
  }

  std::vector<GlobalData*> args = {dot_lhs_handle.get(), dot_rhs_handle.get()};
  if (param.has_addend) {
    args.push_back(addend_handle.get());
  }

  ComputeAndCompareR2<float>(&builder, *expected, args, ErrorSpec(0.3, 3e-3));
}

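// Enumerates matrix-matrix shapes across all four operand-layout
// combinations, and matrix-vector shapes (m == 1, and its n == 1 mirror) with
// and without an addend.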
std::vector<DotTestParam> CreateDotTestParameters() {
  std::vector<DotTestParam> params;

  auto add_matrix_matrix_dot_test = [&](int m, int k, int n) {
    for (bool lhs_row_major : {true, false}) {
      for (bool rhs_row_major : {true, false}) {
        params.push_back({/*m=*/m, /*k=*/k, /*n=*/n,
                          /*dot_lhs_row_major=*/lhs_row_major,
                          /*dot_rhs_row_major=*/rhs_row_major,
                          /*has_addend=*/false, /*addend_row_major=*/true});
      }
    }
  };

  auto add_matrix_vector_dot_test = [&](int k, int n) {
    for (bool has_addend : {false, true}) {
      params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
                        /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true,
                        /*has_addend=*/has_addend, /*addend_row_major=*/true});
      if (n != 1) {
        params.push_back(
            {/*m=*/n, /*k=*/k, /*n=*/1,
             /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true,
             /*has_addend=*/has_addend, /*addend_row_major=*/true});
      }
    }
  };

  add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
  add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
  add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);

  add_matrix_vector_dot_test(/*k=*/8, /*n=*/8);
  add_matrix_vector_dot_test(/*k=*/130, /*n=*/8);
  add_matrix_vector_dot_test(/*k=*/8, /*n=*/130);
  add_matrix_vector_dot_test(/*k=*/290, /*n=*/130);
  add_matrix_vector_dot_test(/*k=*/1, /*n=*/1);
  add_matrix_vector_dot_test(/*k=*/1, /*n=*/16);
  add_matrix_vector_dot_test(/*k=*/3, /*n=*/16);
  add_matrix_vector_dot_test(/*k=*/3, /*n=*/3);
  add_matrix_vector_dot_test(/*k=*/29, /*n=*/29);
  add_matrix_vector_dot_test(/*k=*/8, /*n=*/2);
  add_matrix_vector_dot_test(/*k=*/2, /*n=*/8);
  add_matrix_vector_dot_test(/*k=*/259, /*n=*/258);

  return params;
}

INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest,
                        ::testing::ValuesIn(CreateDotTestParameters()),
                        PrintDotTestParam);

XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
  TestSquareMatrixDot<float>(false, false);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) {
  TestSquareMatrixDot<float>(false, true);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) {
  TestSquareMatrixDot<float>(true, false);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) {
  TestSquareMatrixDot<float>(true, true);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) {
  TestSquareMatrixDot<complex64>(false, false);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) {
  TestSquareMatrixDot<complex64>(false, true);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) {
  TestSquareMatrixDot<complex64>(true, false);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) {
  TestSquareMatrixDot<complex64>(true, true);
}

XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) {
  TestSquareMatrixDot<double>();
}

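// Computes a 2x3 times 3x2 dot with the requested operand layouts; the
// expected product is {{26, 0}, {-12, 10}}.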
template <typename Element>
void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
                                              bool rhs_row_major) {
  auto lhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
              {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
          .ConsumeValueOrDie();
  auto rhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
              {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
          .ConsumeValueOrDie();

  ComputationBuilder builder(client_, TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<Element>();
  auto result = builder.Dot(
      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));

  Array2D<Element> expected({{26.0, 0.0}, {-12.0, 10.0}});

  ComputeAndCompareR2<Element>(
      &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) {
  TestNonsquareMatrixDot<float>(false, false);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) {
  TestNonsquareMatrixDot<float>(false, true);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) {
  TestNonsquareMatrixDot<float>(true, false);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
  TestNonsquareMatrixDot<float>(true, true);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) {
  TestNonsquareMatrixDot<double>();
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) {
  TestNonsquareMatrixDot<complex64>(false, false);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) {
  TestNonsquareMatrixDot<complex64>(false, true);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) {
  TestNonsquareMatrixDot<complex64>(true, false);
}

XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) {
  TestNonsquareMatrixDot<complex64>(true, true);
}

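// A 1x4 complex row times a 4x2 complex matrix: {1, 2, 3, -4} against the two
// columns yields {30, -2}.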
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
  auto lhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
              {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
          .ConsumeValueOrDie();
  auto rhs_handle =
      client_
          ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
              LayoutUtil::MakeLayout({1, 0})))
          .ConsumeValueOrDie();

  ComputationBuilder builder(client_, TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
  auto result = builder.Dot(
      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));

  Array2D<complex64> expected({{30.0, -2.0}});

  ComputeAndCompareR2<complex64>(
      &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}

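// Two independent dots (m1*m2 and m2*m1) feed a single Add; with no data
// dependence between them, the two dots may execute concurrently.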
XLA_TEST_F(DotOperationTest, ConcurrentMatMul) {
  ComputationBuilder builder(client_, TestName());
  auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}});
  auto matrix12 = builder.Dot(matrix1, matrix2);
  auto matrix21 = builder.Dot(matrix2, matrix1);
  builder.Add(matrix12, matrix21);

  Array2D<float> expected({{42.0, 56.0}, {74.0, 96.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// Regression test for b/32055648. The root of the graph is a kFusion of 4
// bitcasts. Although bitcasts don't map to thunks, the root should still be
// sync-dependent on bitcasts' operands.
XLA_TEST_F(DotOperationTest, BatchMatMul) {
  ComputationBuilder builder(client_, TestName());
  auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x");
  auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y");

  auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
  auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2});

  // Slice batches into individual matrices and multiply them.
  std::vector<xla::ComputationDataHandle> out_slices;
  for (int i = 0; i < 4; ++i) {
    // Slice off individual matrices and reshape to 2D tensors.
    auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
    x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
    auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
    y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});

    auto out = builder.Dot(x_slice, y_slice);
    out = builder.Reshape(out, {0, 1}, {1, 2, 2});
    out_slices.push_back(out);
  }
  auto out_flat = builder.ConcatInDim(out_slices, 0);
  builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});

  auto x_data = client_
                    ->TransferToServer(*Literal::CreateR4<float>(
                        {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}},
                         {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}}))
                    .ConsumeValueOrDie();
  auto y_data = client_
                    ->TransferToServer(*Literal::CreateR4<float>(
                        {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}},
                         {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}}))
                    .ConsumeValueOrDie();

  ComputeAndCompareR4<float>(
      &builder,
      /*expected=*/
      {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}},
       {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}},
      {x_data.get(), y_data.get()}, error_spec_);
}

XLA_TEST_F(DotOperationTest, GeneralMatMul) {
  ComputationBuilder builder(client_, TestName());
  auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x");
  auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y");

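  // Batch both operands on dimension 0 and contract lhs dimension 2 with rhs
  // dimension 1, i.e. out[b, i, j] = sum_k x[b, i, k] * y[b, k, j].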
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  auto out = builder.DotGeneral(x, y, dnums);

  auto x_data = client_
                    ->TransferToServer(*Literal::CreateR3<float>(
                        {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}))
                    .ConsumeValueOrDie();

  auto y_data = client_
                    ->TransferToServer(*Literal::CreateR3<float>(
                        {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}}))
                    .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(
      &builder,
      /*expected=*/
      {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}},
      {x_data.get(), y_data.get()}, error_spec_);
}

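// Dots whose operands pass through an explicit Transpose; the backend may
// fold the transpose into the dot, and the result must match the
// untransposed reference for every layout combination.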
TEST_F(DotOperationTest, TransposeFolding) {
  for (bool transpose_lhs : {false, true}) {
    for (bool transpose_rhs : {false, true}) {
      for (bool row_major : {false, true}) {
        std::unique_ptr<Array2D<float>> lhs(
            new Array2D<float>({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}));
        std::unique_ptr<Array2D<float>> rhs(
            new Array2D<float>({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}));

        if (transpose_lhs) {
          lhs = ReferenceUtil::TransposeArray2D(*lhs);
        }
        if (transpose_rhs) {
          rhs = ReferenceUtil::TransposeArray2D(*rhs);
        }
        auto lhs_handle =
            client_
                ->TransferToServer(
                    *Literal::CreateR2FromArray2DWithLayout<float>(
                        *lhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .ConsumeValueOrDie();
        auto rhs_handle =
            client_
                ->TransferToServer(
                    *Literal::CreateR2FromArray2DWithLayout<float>(
                        *rhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .ConsumeValueOrDie();

        ComputationBuilder builder(client_, TestName());
        auto prim_type = primitive_util::NativeToPrimitiveType<float>();
        auto lhs_arg = builder.Parameter(
            0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
            "lhs");
        auto rhs_arg = builder.Parameter(
            1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
            "rhs");
        if (transpose_lhs) {
          lhs_arg = builder.Transpose(lhs_arg, {1, 0});
        }
        if (transpose_rhs) {
          rhs_arg = builder.Transpose(rhs_arg, {1, 0});
        }
        auto result = builder.Dot(lhs_arg, rhs_arg);

        Array2D<float> expected({{26.0, 0.0}, {-12.0, 10.0}});
        VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
                << transpose_rhs << " " << row_major;
        ComputeAndCompareR2<float>(&builder, expected,
                                   {lhs_handle.get(), rhs_handle.get()},
                                   error_spec_);
      }
    }
  }
}

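// Dot(constant_lhs, ConcatInDim(args, 0)): the rhs concatenation runs along
// the contracting dimension, which the optimizer may rewrite as a sum of
// smaller dots against slices of the constant.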
TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) {
  auto prim_type = primitive_util::NativeToPrimitiveType<float>();

  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));

  ComputationBuilder builder(client_, TestName());
  auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
  auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}),
                                     "rhs_arg_0");
  auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}),
                                     "rhs_arg_1");
  auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}),
                                     "rhs_arg_2");
  auto result = builder.Dot(
      lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));

  std::unique_ptr<Array2D<float>> arg_0_value_array(
      new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}}));
  std::unique_ptr<Array2D<float>> arg_1_value_array(
      new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}));
  std::unique_ptr<Array2D<float>> arg_2_value_array(
      new Array2D<float>({{1.0, 2.0}}));

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_2_value_array)));

  Array2D<float> expected({{53.0, 74.0}, {45.0, 66.0}});
  ComputeAndCompareR2<float>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_);
}

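// The mirrored case: the lhs concatenation runs along its contracting
// dimension (columns) and is dotted with a constant rhs.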
TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) {
  auto prim_type = primitive_util::NativeToPrimitiveType<float>();

  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0},
                          {3.0, 4.0},
                          {5.0, 6.0},
                          {6.0, 5.0},
                          {4.0, 3.0},
                          {2.0, 1.0}}));

  ComputationBuilder builder(client_, TestName());
  auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
  auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}),
                                     "lhs_arg_0");
  auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}),
                                     "lhs_arg_1");
  auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}),
                                     "lhs_arg_2");
  auto result = builder.Dot(
      builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant);

  std::unique_ptr<Array2D<float>> arg_0_value_array(
      new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}}));
  std::unique_ptr<Array2D<float>> arg_1_value_array(
      new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
  std::unique_ptr<Array2D<float>> arg_2_value_array(
      new Array2D<float>({{1.0}, {2.0}}));

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      client_->TransferToServer(
          *Literal::CreateR2FromArray2D<float>(*arg_2_value_array)));

  Array2D<float> expected({{38.0, 36.0}, {93.0, 91.0}});
  ComputeAndCompareR2<float>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_);
}
}  // namespace
}  // namespace xla