/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace {

// TODO(b/34468543): use GUnit typed tests when we can do all tests on all
// backends.
39class DotOperationTest : public ClientLibraryTestBase { 40 public: 41 ErrorSpec error_spec_{0.0001, 1e-5}; 42 43 protected: 44 template <typename Element> 45 void TestOneElementVectorDot(); 46 template <typename Element> 47 void TestVectorDot(); 48 template <typename Element> 49 void TestSquareMatrixDot(bool lhs_row_major = false, 50 bool rhs_row_major = false); 51 template <typename Element> 52 void TestNonsquareMatrixDot(bool lhs_row_major = false, 53 bool rhs_row_major = false); 54}; 55 56XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { 57 ComputationBuilder builder(client_, TestName()); 58 auto lhs = builder.ConstantR1<float>({}); 59 auto rhs = builder.ConstantR1<float>({}); 60 auto result = builder.Dot(lhs, rhs); 61 62 ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_); 63} 64 65XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) { 66 ComputationBuilder builder(client_, TestName()); 67 auto lhs = builder.ConstantR2<float>({{3.0, 4.0}}); 68 auto rhs = builder.ConstantR1<float>({3.0, 4.0}); 69 auto result = builder.Dot(lhs, rhs); 70 71 ComputeAndCompareR1<float>(&builder, {25.0}, {}, error_spec_); 72} 73 74template <typename Element> 75void DotOperationTest::TestOneElementVectorDot() { 76 ComputationBuilder builder(client_, TestName()); 77 auto lhs = builder.ConstantR1<Element>({2.0}); 78 auto rhs = builder.ConstantR1<Element>({3.0}); 79 auto result = builder.Dot(lhs, rhs); 80 81 ComputeAndCompareR0<Element>(&builder, 6.0, {}, error_spec_); 82} 83 84XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) { 85 TestOneElementVectorDot<float>(); 86} 87 88XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) { 89 TestOneElementVectorDot<double>(); 90} 91 92template <typename Element> 93void DotOperationTest::TestVectorDot() { 94 ComputationBuilder builder(client_, TestName()); 95 auto lhs = builder.ConstantR1<Element>({1.0, 2.5, 42.0}); 96 auto rhs = builder.ConstantR1<Element>({11.0, -1.0, 0.5}); 97 auto result = builder.Dot(lhs, rhs); 98 99 
ComputeAndCompareR0<Element>(&builder, 29.5, {}, error_spec_); 100} 101 102XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot<float>(); } 103 104XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot<double>(); } 105 106namespace { 107 108std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) { 109 return {row_major ? 1 : 0, row_major ? 0 : 1}; 110} 111 112} // namespace 113 114XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) { 115 ComputationBuilder builder(client_, TestName()); 116 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2)); 117 auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0)); 118 auto result = builder.Dot(lhs, rhs); 119 120 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_); 121} 122 123XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) { 124 ComputationBuilder builder(client_, TestName()); 125 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2)); 126 auto rhs = builder.ConstantR2<float>({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}}); 127 auto result = builder.Dot(lhs, rhs); 128 129 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 3), {}, error_spec_); 130} 131 132XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) { 133 ComputationBuilder builder(client_, TestName()); 134 auto lhs = 135 builder.ConstantR2<float>({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}}); 136 auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0)); 137 auto result = builder.Dot(lhs, rhs); 138 139 ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {}, error_spec_); 140} 141 142XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) { 143 ComputationBuilder builder(client_, TestName()); 144 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0)); 145 auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2)); 146 auto result = builder.Dot(lhs, rhs); 147 148 ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 2, 0.0f), {}, 149 error_spec_); 150} 151 152XLA_TEST_F(DotOperationTest, 
FusedDot) { 153 ComputationBuilder builder(client_, TestName()); 154 auto param0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 4}), "arg0"); 155 auto param1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4, 1}), "arg1"); 156 auto exp0 = builder.Exp(param0); 157 auto result = builder.Dot(exp0, param1); 158 159 auto lhs_handle = client_ 160 ->TransferToServer(*Literal::CreateR2<float>( 161 {{1.0, 2.0, 3.0, 4.0}, {-1.0, -2.0, -3.0, -4.0}})) 162 .ConsumeValueOrDie(); 163 auto rhs_handle = client_ 164 ->TransferToServer(*Literal::CreateR2<float>( 165 {{1.0}, {2.0}, {3.0}, {4.0}})) 166 .ConsumeValueOrDie(); 167 168 ComputeAndCompareR2<float>( 169 &builder, Array2D<float>({{296.14560492846033}, {0.8611737683031964}}), 170 {lhs_handle.get(), rhs_handle.get()}, error_spec_); 171} 172 173template <typename Element> 174void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, 175 bool rhs_row_major) { 176 auto lhs_handle = 177 client_ 178 ->TransferToServer(*Literal::CreateR2WithLayout<Element>( 179 {{1.0, 2.0}, {3.0, -4.0}}, 180 LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) 181 .ConsumeValueOrDie(); 182 auto rhs_handle = 183 client_ 184 ->TransferToServer(*Literal::CreateR2WithLayout<Element>( 185 {{1.0, 6.0}, {7.0, -4.0}}, 186 LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) 187 .ConsumeValueOrDie(); 188 189 ComputationBuilder builder(client_, TestName()); 190 auto prim_type = primitive_util::NativeToPrimitiveType<Element>(); 191 auto result = builder.Dot( 192 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), 193 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); 194 195 Array2D<Element> expected({{15.0, -2.0}, {-25.0, 34.0}}); 196 ComputeAndCompareR2<Element>( 197 &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); 198} 199 200struct DotTestParam { 201 int m; 202 int k; 203 int n; 204 bool dot_lhs_row_major; 205 bool dot_rhs_row_major; 206 bool has_addend; 207 
bool addend_row_major; 208}; 209 210string PrintDotTestParam( 211 const ::testing::TestParamInfo<DotTestParam>& test_param) { 212 const DotTestParam& param = test_param.param; 213 if (param.has_addend) { 214 return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, 215 "_MajorToMinor", 216 param.dot_lhs_row_major ? "T" : "F", 217 param.dot_rhs_row_major ? "T" : "F", 218 param.addend_row_major ? "T" : "F"); 219 } else { 220 return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, 221 "_MajorToMinor", 222 param.dot_lhs_row_major ? "T" : "F", 223 param.dot_rhs_row_major ? "T" : "F"); 224 } 225} 226 227class ParametricDotTest : public DotOperationTest, 228 public ::testing::WithParamInterface<DotTestParam> {}; 229 230XLA_TEST_P(ParametricDotTest, TestF32) { 231 DotTestParam param = GetParam(); 232 233 std::unique_ptr<Array2D<float>> dot_lhs_data = 234 MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); 235 std::unique_ptr<Literal> dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( 236 *dot_lhs_data, LayoutUtil::MakeLayout( 237 MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); 238 std::unique_ptr<GlobalData> dot_lhs_handle = 239 client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); 240 241 std::unique_ptr<Array2D<float>> dot_rhs_data = 242 MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); 243 std::unique_ptr<Literal> dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout( 244 *dot_rhs_data, LayoutUtil::MakeLayout( 245 MinorToMajorForIsRowMajor(param.dot_rhs_row_major))); 246 std::unique_ptr<GlobalData> dot_rhs_handle = 247 client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); 248 249 std::unique_ptr<Array2D<float>> addend_data; 250 std::unique_ptr<Literal> addend_lit; 251 std::unique_ptr<GlobalData> addend_handle; 252 253 if (param.has_addend) { 254 addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); 255 addend_lit = Literal::CreateR2FromArray2DWithLayout( 256 *addend_data, LayoutUtil::MakeLayout( 257 
MinorToMajorForIsRowMajor(param.addend_row_major))); 258 addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); 259 } 260 261 ComputationBuilder builder(client_, TestName()); 262 auto prim_type = primitive_util::NativeToPrimitiveType<float>(); 263 auto result = builder.Dot( 264 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}), 265 "dot_lhs"), 266 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}), 267 "dot_rhs")); 268 269 if (param.has_addend) { 270 result = builder.Add( 271 result, 272 builder.Parameter( 273 2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend")); 274 } 275 276 std::unique_ptr<Array2D<float>> expected; 277 if (param.has_addend) { 278 expected = ReferenceUtil::ApplyElementwise2D( 279 std::plus<float>(), 280 *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), 281 *addend_data); 282 } else { 283 expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data); 284 } 285 286 std::vector<GlobalData*> args = {dot_lhs_handle.get(), dot_rhs_handle.get()}; 287 if (param.has_addend) { 288 args.push_back(addend_handle.get()); 289 } 290 291 ComputeAndCompareR2<float>(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); 292} 293 294std::vector<DotTestParam> CreateDotTestParameters() { 295 std::vector<DotTestParam> params; 296 297 auto add_matrix_matrix_dot_test = [&](int m, int k, int n) { 298 for (bool lhs_row_major : {true, false}) { 299 for (bool rhs_row_major : {true, false}) { 300 params.push_back({/*m=*/m, /*k=*/k, /*n=*/n, 301 /*dot_lhs_row_major=*/lhs_row_major, 302 /*dot_rhs_row_major=*/rhs_row_major, 303 /*has_addend=*/false, /*addend_row_major=*/true}); 304 } 305 } 306 }; 307 308 auto add_matrix_vector_dot_test = [&](int k, int n) { 309 for (bool has_addend : {false, true}) { 310 params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, 311 /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, 312 /*has_addend=*/has_addend, /*addend_row_major=*/true}); 313 if (n != 1) { 
314 params.push_back( 315 {/*m=*/n, /*k=*/k, /*n=*/1, 316 /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, 317 /*has_addend=*/has_addend, /*addend_row_major=*/true}); 318 } 319 } 320 }; 321 322 add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); 323 add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); 324 add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); 325 326 add_matrix_vector_dot_test(/*k=*/8, /*n=*/8); 327 add_matrix_vector_dot_test(/*k=*/130, /*n=*/8); 328 add_matrix_vector_dot_test(/*k=*/8, /*n=*/130); 329 add_matrix_vector_dot_test(/*k=*/290, /*n=*/130); 330 add_matrix_vector_dot_test(/*k=*/1, /*n=*/1); 331 add_matrix_vector_dot_test(/*k=*/1, /*n=*/16); 332 add_matrix_vector_dot_test(/*k=*/3, /*n=*/16); 333 add_matrix_vector_dot_test(/*k=*/3, /*n=*/3); 334 add_matrix_vector_dot_test(/*k=*/29, /*n=*/29); 335 add_matrix_vector_dot_test(/*k=*/8, /*n=*/2); 336 add_matrix_vector_dot_test(/*k=*/2, /*n=*/8); 337 add_matrix_vector_dot_test(/*k=*/259, /*n=*/258); 338 339 return params; 340} 341 342INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, 343 ::testing::ValuesIn(CreateDotTestParameters()), 344 PrintDotTestParam); 345 346XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { 347 TestSquareMatrixDot<float>(false, false); 348} 349 350XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { 351 TestSquareMatrixDot<float>(false, true); 352} 353 354XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { 355 TestSquareMatrixDot<float>(true, false); 356} 357 358XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { 359 TestSquareMatrixDot<float>(true, true); 360} 361 362XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { 363 TestSquareMatrixDot<complex64>(false, false); 364} 365 366XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { 367 TestSquareMatrixDot<complex64>(false, true); 368} 369 370XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { 371 
TestSquareMatrixDot<complex64>(true, false); 372} 373 374XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { 375 TestSquareMatrixDot<complex64>(true, true); 376} 377 378XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { 379 TestSquareMatrixDot<double>(); 380} 381 382template <typename Element> 383void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, 384 bool rhs_row_major) { 385 auto lhs_handle = 386 client_ 387 ->TransferToServer(*Literal::CreateR2WithLayout<Element>( 388 {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, 389 LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) 390 .ConsumeValueOrDie(); 391 auto rhs_handle = 392 client_ 393 ->TransferToServer(*Literal::CreateR2WithLayout<Element>( 394 {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, 395 LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) 396 .ConsumeValueOrDie(); 397 398 ComputationBuilder builder(client_, TestName()); 399 auto prim_type = primitive_util::NativeToPrimitiveType<Element>(); 400 auto result = builder.Dot( 401 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), 402 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); 403 404 Array2D<Element> expected({{26.0, 0.0}, {-12.0, 10.0}}); 405 406 ComputeAndCompareR2<Element>( 407 &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); 408} 409 410XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { 411 TestNonsquareMatrixDot<float>(false, false); 412} 413 414XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { 415 TestNonsquareMatrixDot<float>(false, true); 416} 417 418XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { 419 TestNonsquareMatrixDot<float>(true, false); 420} 421 422XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { 423 TestNonsquareMatrixDot<float>(true, true); 424} 425 426XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { 427 TestNonsquareMatrixDot<double>(); 428} 429 
430XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) { 431 TestNonsquareMatrixDot<complex64>(false, false); 432} 433 434XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) { 435 TestNonsquareMatrixDot<complex64>(false, true); 436} 437 438XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) { 439 TestNonsquareMatrixDot<complex64>(true, false); 440} 441 442XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) { 443 TestNonsquareMatrixDot<complex64>(true, true); 444} 445 446XLA_TEST_F(DotOperationTest, MatrixVectorC64) { 447 auto lhs_handle = 448 client_ 449 ->TransferToServer(*Literal::CreateR2WithLayout<complex64>( 450 {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) 451 .ConsumeValueOrDie(); 452 auto rhs_handle = 453 client_ 454 ->TransferToServer(*Literal::CreateR2WithLayout<complex64>( 455 {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, 456 LayoutUtil::MakeLayout({1, 0}))) 457 .ConsumeValueOrDie(); 458 459 ComputationBuilder builder(client_, TestName()); 460 auto prim_type = primitive_util::NativeToPrimitiveType<complex64>(); 461 auto result = builder.Dot( 462 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), 463 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); 464 465 Array2D<complex64> expected({{30.0, -2.0}}); 466 467 ComputeAndCompareR2<complex64>( 468 &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); 469} 470 471XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { 472 ComputationBuilder builder(client_, TestName()); 473 auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}}); 474 auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}}); 475 auto matrix12 = builder.Dot(matrix1, matrix2); 476 auto matrix21 = builder.Dot(matrix2, matrix1); 477 builder.Add(matrix12, matrix21); 478 479 Array2D<float> expected({{42.0, 56.0}, {74.0, 96.0}}); 480 ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); 481} 482 
483// Regression test for b/32055648. The root of the graph is a kFusion of 4 484// bitcasts. Although bitcasts don't map to thunks, the root should still be 485// sync-dependent on bitcasts' operands. 486XLA_TEST_F(DotOperationTest, BatchMatMul) { 487 ComputationBuilder builder(client_, TestName()); 488 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x"); 489 auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y"); 490 491 auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); 492 auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); 493 494 // Slice batches into individual matrices and multiply them. 495 std::vector<xla::ComputationDataHandle> out_slices; 496 for (int i = 0; i < 4; ++i) { 497 // Slice off individual matrices and reshape to 2D tensors. 498 auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); 499 x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2}); 500 auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); 501 y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2}); 502 503 auto out = builder.Dot(x_slice, y_slice); 504 out = builder.Reshape(out, {0, 1}, {1, 2, 2}); 505 out_slices.push_back(out); 506 } 507 auto out_flat = builder.ConcatInDim(out_slices, 0); 508 builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); 509 510 auto x_data = client_ 511 ->TransferToServer(*Literal::CreateR4<float>( 512 {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}}, 513 {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}})) 514 .ConsumeValueOrDie(); 515 auto y_data = client_ 516 ->TransferToServer(*Literal::CreateR4<float>( 517 {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, 518 {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}})) 519 .ConsumeValueOrDie(); 520 521 ComputeAndCompareR4<float>( 522 &builder, 523 /*expected=*/ 524 {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, 525 {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}}, 526 {x_data.get(), y_data.get()}, 
error_spec_); 527} 528 529XLA_TEST_F(DotOperationTest, GeneralMatMul) { 530 ComputationBuilder builder(client_, TestName()); 531 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x"); 532 auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y"); 533 534 DotDimensionNumbers dnums; 535 dnums.add_lhs_contracting_dimensions(2); 536 dnums.add_rhs_contracting_dimensions(1); 537 dnums.add_lhs_batch_dimensions(0); 538 dnums.add_rhs_batch_dimensions(0); 539 540 auto out = builder.DotGeneral(x, y, dnums); 541 542 auto x_data = client_ 543 ->TransferToServer(*Literal::CreateR3<float>( 544 {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}})) 545 .ConsumeValueOrDie(); 546 547 auto y_data = client_ 548 ->TransferToServer(*Literal::CreateR3<float>( 549 {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}})) 550 .ConsumeValueOrDie(); 551 552 ComputeAndCompareR3<float>( 553 &builder, 554 /*expected=*/ 555 {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}, 556 {x_data.get(), y_data.get()}, error_spec_); 557} 558 559TEST_F(DotOperationTest, TransposeFolding) { 560 for (bool transpose_lhs : {false, true}) { 561 for (bool transpose_rhs : {false, true}) { 562 for (bool row_major : {false, true}) { 563 std::unique_ptr<Array2D<float>> lhs( 564 new Array2D<float>({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}})); 565 std::unique_ptr<Array2D<float>> rhs( 566 new Array2D<float>({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}})); 567 568 if (transpose_lhs) { 569 lhs = ReferenceUtil::TransposeArray2D(*lhs); 570 } 571 if (transpose_rhs) { 572 rhs = ReferenceUtil::TransposeArray2D(*rhs); 573 } 574 auto lhs_handle = 575 client_ 576 ->TransferToServer( 577 *Literal::CreateR2FromArray2DWithLayout<float>( 578 *lhs, LayoutUtil::MakeLayout( 579 MinorToMajorForIsRowMajor(row_major)))) 580 .ConsumeValueOrDie(); 581 auto rhs_handle = 582 client_ 583 ->TransferToServer( 584 *Literal::CreateR2FromArray2DWithLayout<float>( 585 *rhs, LayoutUtil::MakeLayout( 586 
MinorToMajorForIsRowMajor(row_major)))) 587 .ConsumeValueOrDie(); 588 589 ComputationBuilder builder(client_, TestName()); 590 auto prim_type = primitive_util::NativeToPrimitiveType<float>(); 591 auto lhs_arg = builder.Parameter( 592 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), 593 "lhs"); 594 auto rhs_arg = builder.Parameter( 595 1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), 596 "rhs"); 597 if (transpose_lhs) { 598 lhs_arg = builder.Transpose(lhs_arg, {1, 0}); 599 } 600 if (transpose_rhs) { 601 rhs_arg = builder.Transpose(rhs_arg, {1, 0}); 602 } 603 auto result = builder.Dot(lhs_arg, rhs_arg); 604 605 Array2D<float> expected({{26.0, 0.0}, {-12.0, 10.0}}); 606 VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " 607 << transpose_rhs << " " << row_major; 608 ComputeAndCompareR2<float>(&builder, expected, 609 {lhs_handle.get(), rhs_handle.get()}, 610 error_spec_); 611 } 612 } 613 } 614} 615 616TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { 617 auto prim_type = primitive_util::NativeToPrimitiveType<float>(); 618 619 std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( 620 {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); 621 622 ComputationBuilder builder(client_, TestName()); 623 auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); 624 auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), 625 "rhs_arg_0"); 626 auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), 627 "rhs_arg_1"); 628 auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), 629 "rhs_arg_2"); 630 auto result = builder.Dot( 631 lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); 632 633 std::unique_ptr<Array2D<float>> arg_0_value_array( 634 new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}})); 635 std::unique_ptr<Array2D<float>> arg_1_value_array( 636 new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, 
{5.0, 6.0}})); 637 std::unique_ptr<Array2D<float>> arg_2_value_array( 638 new Array2D<float>({{1.0, 2.0}})); 639 640 TF_ASSERT_OK_AND_ASSIGN( 641 auto arg_0_value, 642 client_->TransferToServer( 643 *Literal::CreateR2FromArray2D<float>(*arg_0_value_array))); 644 TF_ASSERT_OK_AND_ASSIGN( 645 auto arg_1_value, 646 client_->TransferToServer( 647 *Literal::CreateR2FromArray2D<float>(*arg_1_value_array))); 648 TF_ASSERT_OK_AND_ASSIGN( 649 auto arg_2_value, 650 client_->TransferToServer( 651 *Literal::CreateR2FromArray2D<float>(*arg_2_value_array))); 652 653 Array2D<float> expected({{53.0, 74.0}, {45.0, 66.0}}); 654 ComputeAndCompareR2<float>( 655 &builder, expected, 656 {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); 657} 658 659TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) { 660 auto prim_type = primitive_util::NativeToPrimitiveType<float>(); 661 662 std::unique_ptr<Array2D<float>> constant_rhs_array( 663 new Array2D<float>({{1.0, 2.0}, 664 {3.0, 4.0}, 665 {5.0, 6.0}, 666 {6.0, 5.0}, 667 {4.0, 3.0}, 668 {2.0, 1.0}})); 669 670 ComputationBuilder builder(client_, TestName()); 671 auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); 672 auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), 673 "lhs_arg_0"); 674 auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}), 675 "lhs_arg_1"); 676 auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}), 677 "lhs_arg_2"); 678 auto result = builder.Dot( 679 builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); 680 681 std::unique_ptr<Array2D<float>> arg_0_value_array( 682 new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}})); 683 std::unique_ptr<Array2D<float>> arg_1_value_array( 684 new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); 685 std::unique_ptr<Array2D<float>> arg_2_value_array( 686 new Array2D<float>({{1.0}, {2.0}})); 687 688 TF_ASSERT_OK_AND_ASSIGN( 689 auto arg_0_value, 690 
client_->TransferToServer( 691 *Literal::CreateR2FromArray2D<float>(*arg_0_value_array))); 692 TF_ASSERT_OK_AND_ASSIGN( 693 auto arg_1_value, 694 client_->TransferToServer( 695 *Literal::CreateR2FromArray2D<float>(*arg_1_value_array))); 696 TF_ASSERT_OK_AND_ASSIGN( 697 auto arg_2_value, 698 client_->TransferToServer( 699 *Literal::CreateR2FromArray2D<float>(*arg_2_value_array))); 700 701 Array2D<float> expected({{38.0, 36.0}, {93.0, 91.0}}); 702 ComputeAndCompareR2<float>( 703 &builder, expected, 704 {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); 705} 706} // namespace 707} // namespace xla 708