11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory>
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <numeric>
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h"
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array4d.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/local_client.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/statusor.h"
265c8acccfc9e90d694a8394f5522097bfe87379b2A. Unique TensorFlower#include "tensorflow/compiler/xla/test.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h"
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/test_macros.h"
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
345f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerclass BroadcastSimpleTest : public ClientLibraryTestBase {
355f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower public:
365f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  ComputationDataHandle BuildBinOp(HloOpcode op,
375f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                                   const ComputationDataHandle& lhs,
385f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                                   const ComputationDataHandle& rhs,
395f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                                   ComputationBuilder* builder) {
405f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    switch (op) {
415f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      case HloOpcode::kMinimum: {
425f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        return builder->Min(lhs, rhs);
435f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      }
445f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      case HloOpcode::kMaximum: {
455f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        return builder->Max(lhs, rhs);
465f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      }
475f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      case HloOpcode::kMultiply: {
485f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        return builder->Mul(lhs, rhs);
495f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      }
505f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      default: {
515f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        // Default to Add
525f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        return builder->Add(lhs, rhs);
535f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      }
545f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    }
555f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  }
565f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
575f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::unique_ptr<GlobalData> MakeR3Data(
585f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      tensorflow::gtl::ArraySlice<int64> bounds,
595f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r3_shape,
605f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      Array3D<float>* r3_array, float start, float end, int seed) {
615f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
625f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    r3_array->FillRandom(start, end, seed);
6346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower    auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout(
6446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower        LayoutUtil::MakeLayout(minor_to_major));
655f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    std::unique_ptr<GlobalData> r3_global_data =
665f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        client_->TransferToServer(*r3_data).ConsumeValueOrDie();
675f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    return r3_global_data;
685f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  }
695f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
705f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::unique_ptr<GlobalData> MakeR2Data(
715f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      tensorflow::gtl::ArraySlice<int64> bounds,
725f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r2_shape,
735f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      Array2D<float>* r2_array, float start, float end, int seed) {
745f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
755f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    r2_array->FillRandom(start, end, seed);
7646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower    auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout(
7746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower        LayoutUtil::MakeLayout(minor_to_major));
785f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    std::unique_ptr<GlobalData> r2_global_data =
795f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        client_->TransferToServer(*r2_data).ConsumeValueOrDie();
805f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    return r2_global_data;
815f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  }
825f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
835f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) {
845f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    switch (op) {
855f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      case HloOpcode::kMinimum: {
865f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        return std::min(lhs, rhs);
875f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      }
885f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      case HloOpcode::kMaximum: {
895f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        return std::max(lhs, rhs);
905f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      }
915f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      case HloOpcode::kMultiply: {
925f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        return lhs * rhs;
935f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      }
945f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      case HloOpcode::kAdd: {
955f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        return lhs + rhs;
965f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      }
975f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      default: {
985f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower        // Default to Add
996dd43ec8cb299459b835e50faa4f3ffad044098cA. Unique TensorFlower        LOG(FATAL);
1005f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      }
1015f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    }
1025f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  }
1035f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower};
1045f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
1059992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlowerusing ::testing::HasSubstr;
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Broadcast(b.ConstantR0<float>(1.5), {});
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR0<float>(&b, 1.5, {}, ErrorSpec(0.0001));
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Broadcast(b.ConstantR0<float>(2.25), {2, 3});
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> expected(2, 3, 2.25);
1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1200b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
1210b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
1220b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower  ComputationDataHandle src;
1230b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower  std::unique_ptr<GlobalData> param_data =
1240b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower      CreateR0Parameter<float>(2.25f, /*parameter_number=*/0, /*name=*/"src",
1250b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower                               /*builder=*/&b, /*data_handle=*/&src);
1260b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower
1270b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower  b.Broadcast(src, {2, 3});
1280b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower  Array2D<float> expected(2, 3, 2.25);
1290b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower  ComputeAndCompareR2<float>(&b, expected, {param_data.get()},
1300b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower                             ErrorSpec(0.0001));
1310b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower}
1320b2d231920170035a5326d7d631ada3f9a472022A. Unique TensorFlower
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Broadcast(b.ConstantR0<float>(2.25), {2, 0});
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> expected(2, 0);
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Broadcast(b.ConstantR0<float>(2.25), {0, 2});
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> expected(0, 2);
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {2});
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> expected(2, 3);
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected(0, 0) = 1;
1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected(0, 1) = 2;
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected(0, 2) = 3;
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected(1, 0) = 1;
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected(1, 1) = 2;
1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected(1, 2) = 3;
1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins// Tests implicit broadcasting of PREDs.
162e56628b085ffa7922e5238537f6ebd6deee0f0ccA. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
163074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  ComputationBuilder b(client_, TestName());
164074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins
165074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  Array2D<bool> x_vals(2, 1);
166074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  x_vals(0, 0) = true;
167074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  x_vals(1, 0) = false;
168074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  Array3D<bool> y_vals(2, 2, 1);
169074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  y_vals(0, 0, 0) = false;
170074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  y_vals(0, 1, 0) = false;
171074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  y_vals(1, 0, 0) = true;
172074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  y_vals(1, 1, 0) = true;
173074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins
174074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  ComputationDataHandle x, y;
175074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x);
176074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y);
177e56628b085ffa7922e5238537f6ebd6deee0f0ccA. Unique TensorFlower  b.And(x, y, /*broadcast_dimensions=*/{1, 2});
178074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins
179074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  Array3D<bool> expected(2, 2, 1);
180074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  expected(0, 0, 0) = false;
181074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  expected(0, 1, 0) = false;
182074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  expected(1, 0, 0) = true;
183074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  expected(1, 1, 0) = false;
184074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins
185074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins  ComputeAndCompareR3<bool>(&b, expected, {x_data.get(), y_data.get()});
186074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins}
187074b92397b77f50cc80f2a04aa28811f1b4f5c97Peter Hawkins
1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Broadcast(b.ConstantR1<float>({}), {2});
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> expected(2, 0);
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {0});
1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> expected(0, 3);
2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Verify that binary op and degenerate dimension broadcast work together in
2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // the same operation.
2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  //
2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // The lhs shape [1, 2] is first broadcast up to [2, 1, 2] using in-dimension
2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape
2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one
2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // dimensions.
2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Add(b.ConstantR2<float>({{1.0, 5.0}}),
21546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower        b.ConstantLiteral(*Literal::CreateR3<float>(
2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        /*broadcast_dimensions=*/{1, 2});
2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto expected =
22046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
22146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower                                {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
224688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
225688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
226688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerstruct R3ImplicitBroadcastSpec {
227688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  std::array<int64, 3> output_bounds;
228688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  std::array<int64, 3> minor2major_layout;
229688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  std::array<int64, 3> input_bounds;
230688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  HloOpcode op;
231688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower} kR3ImplicitBroadcastTestCases[] = {
232688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
233688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum},
234688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum},
235688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply},
236688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
237688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd},
238688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd},
239688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd},
240688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum},
241688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd},
242688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower};
243688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
244688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerclass BroadcastR3ImplicitTest
245688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    : public BroadcastSimpleTest,
246688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower      public ::testing::WithParamInterface<R3ImplicitBroadcastSpec> {};
247688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
248688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
249688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  const R3ImplicitBroadcastSpec& spec = GetParam();
250688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
2515f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
2525f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  Shape r3_shape, r3_implicit_shape;
253688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  Array3D<float> r3_array(spec.output_bounds[0], spec.output_bounds[1],
254688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower                          spec.output_bounds[2]);
255688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  Array3D<float> r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1],
256688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower                                   spec.input_bounds[2]);
2575f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
2585f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::unique_ptr<GlobalData> r3_global_data =
2595f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape,
2605f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                 &r3_array, 1.0, 2.5, 56789);
261688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  std::unique_ptr<GlobalData> r3_implicit_global_data =
2625f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape,
2635f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                 &r3_implicit_array, 1.0, 0.2, 56789);
264688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
265688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input");
266688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r3_parameter = builder.Parameter(1, r3_shape, "input");
2675f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  ComputationDataHandle op =
2685f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
269688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
270688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
271688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower                                spec.output_bounds[2]);
272688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto Each = ([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
273688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0],
274688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower                                          indices[1] % spec.input_bounds[1],
275688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower                                          indices[2] % spec.input_bounds[2]);
276688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    float r3 = r3_array(indices[0], indices[1], indices[2]);
2775f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    *value = ApplyOpToFloats(spec.op, r3_implicit, r3);
278688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  });
279688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
280688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  int n1 = expected_array.n1();
281688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  int n2 = expected_array.n2();
282688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  int n3 = expected_array.n3();
283688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  for (int64 i = 0; i < n1; i++) {
284688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    for (int64 j = 0; j < n2; j++) {
285688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower      for (int64 k = 0; k < n3; k++) {
286688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower        Each({i, j, k}, &expected_array(i, j, k));
287688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower      }
288688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower    }
289688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  }
29046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto expected = Literal::CreateR3FromArray3D(expected_array);
291688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(
292688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower      &builder, *expected,
293688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower      {r3_implicit_global_data.get(), r3_global_data.get()},
294688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower      ErrorSpec(1e-7, 1e-7));
295688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
296688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
297688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerINSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances,
298688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower                        BroadcastR3ImplicitTest,
299688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower                        ::testing::ValuesIn(kR3ImplicitBroadcastTestCases));
300688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
301688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower// r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1:
302688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
303688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
304688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationDataHandle r1h;
305688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationDataHandle r3h;
306688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
307688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  Array3D<float> r1d = {{{1}}, {{2}}};
308688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  Array3D<float> r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
309688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h);
310688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h);
311688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
312688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  b.Add(r3h, r1h);
313688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
314688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto expected =
31546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
316688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
317688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
318688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower                           ErrorSpec(0.0001));
319688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
320688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
321688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
322688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
32346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}}));
324688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r3 = b.ConstantLiteral(
32546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
326688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  b.Add(r3, r1);
327688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
328688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto expected =
32946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
330688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
331688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
332688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
333688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
334688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
335688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
33646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}}));
337688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r3 = b.ConstantLiteral(
33846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
339688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  b.Add(r3, r1);
340688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
341688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto expected =
34246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
343688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
344688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
345688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
346688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
347688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
348688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
34946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}, {3, 4}}}));
350688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r3 = b.ConstantLiteral(
35146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
352688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  b.Add(r3, r1);
353688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
354688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto expected =
35546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
356688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
357688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
358688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
359688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
360688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
361688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
36246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
363688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r3 = b.ConstantLiteral(
36446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
365688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  b.Add(r3, r1);
366688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
367688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto expected =
36846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
369688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
370688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
371688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
372688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
373688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
374688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
37546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r1 =
37646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
377688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r3 = b.ConstantLiteral(
37846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
379688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  b.Add(r3, r1);
380688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
381688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto expected =
38246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
383688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
384688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
385688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
386688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
387688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
388688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
38946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}}}));
390688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto r3 = b.ConstantLiteral(
39146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
392688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  b.Add(r3, r1);
393688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
394688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  auto expected =
39546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
396688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
397688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
398688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
399688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
4005f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerstruct R2ImplicitBroadcastSpec {
4015f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::array<int64, 2> output_bounds;
4025f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::array<int64, 2> minor2major_layout;
4035f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::array<int64, 2> input_bounds1;
4045f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::array<int64, 2> input_bounds2;
4055f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  HloOpcode op1;
4065f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  HloOpcode op2;
4075f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower} kR2ImplicitBroadcastTestCases[] = {
4085f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
4095f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd},
4105f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{2, 3}},
4115f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4125f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{2, 1}},
4135f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 1}},
4145f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4155f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kMinimum},
4165f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{2, 3}},
4175f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4185f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 3}},
4195f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 1}},
4205f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4215f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kMinimum},
4225f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{2, 3}},
4235f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4245f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 1}},
4255f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 1}},
4265f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4275f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kMinimum},
4285f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
4295f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{150, 150}},
4305f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4315f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{150, 1}},
4325f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{150, 1}},
4335f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4345f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd},
4355f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{150, 150}},
4365f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4375f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{150, 1}},
4385f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 150}},
4395f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4405f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd},
4415f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{150, 150}},
4425f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4435f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{150, 1}},
4445f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 1}},
4455f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4465f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd},
4475f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{50, 150}},
4485f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4495f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{50, 1}},
4505f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{50, 1}},
4515f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4525f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd},
4535f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{50, 150}},
4545f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4555f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{50, 1}},
4565f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 150}},
4575f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4585f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd},
4595f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{50, 150}},
4605f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4615f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{50, 1}},
4625f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 1}},
4635f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4645f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd},
4655f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{150, 50}},
4665f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4675f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{150, 1}},
4685f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{150, 1}},
4695f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4705f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd},
4715f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{150, 50}},
4725f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4735f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{150, 1}},
4745f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 50}},
4755f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4765f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd},
4775f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    {{{150, 50}},
4785f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 0}},
4795f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{150, 1}},
4805f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     {{1, 1}},
4815f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd,
4825f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower     HloOpcode::kAdd}};
4835f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
4845f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerclass BroadcastR2ImplicitTest
4855f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    : public BroadcastSimpleTest,
4865f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      public ::testing::WithParamInterface<R2ImplicitBroadcastSpec> {};
4875f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
4885f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower// Test r2 op1 r2_implicit_1 op2 r2_implicit_2
4895f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower// where R2 is a rank-2 operand, and r2_implicit_2 are two
4905f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower// rank-2 operands with degenerate dimensions:
4915f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerXLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
4925f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  const R2ImplicitBroadcastSpec& spec = GetParam();
4935f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
4945f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
4955f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
4965f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  // Operands with degenerate dimensions require implicit broadcasting:
4975f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2;
4985f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  Array2D<float> r2_array(spec.output_bounds[0], spec.output_bounds[1]);
4995f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  Array2D<float> r2_implicit_array1(spec.input_bounds1[0],
5005f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                                    spec.input_bounds1[1]);
5015f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  Array2D<float> r2_implicit_array2(spec.input_bounds2[0],
5025f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                                    spec.input_bounds2[1]);
5035f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
5045f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::unique_ptr<GlobalData> r2_global_data =
5055f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      MakeR2Data(spec.output_bounds, spec.minor2major_layout, &r2_shape,
5065f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                 &r2_array, 1.0, 2.5, 56789);
5075f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::unique_ptr<GlobalData> r2_implicit_global_data1 =
5085f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      MakeR2Data(spec.input_bounds1, spec.minor2major_layout,
5095f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                 &r2_implicit_shape1, &r2_implicit_array1, 1.0, 0.2, 56789);
5105f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  std::unique_ptr<GlobalData> r2_implicit_global_data2 =
5115f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      MakeR2Data(spec.input_bounds2, spec.minor2major_layout,
5125f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                 &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789);
5135f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
5145f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  auto r2_implicit_parameter1 =
5155f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      builder.Parameter(0, r2_implicit_shape1, "input0");
5165f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  auto r2_parameter = builder.Parameter(1, r2_shape, "input1");
5175f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  auto r2_implicit_parameter2 =
5185f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      builder.Parameter(2, r2_implicit_shape2, "input2");
5195f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
5205f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  ComputationDataHandle op1 =
5215f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder);
5225f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  ComputationDataHandle op2 =
5235f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
5245f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
5255f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]);
5265f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
5275f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  expected_array.Each([&](int64 i, int64 j, float* v) {
5285f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    float v1 = r2_implicit_array1(i % spec.input_bounds1[0],
5295f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                                  j % spec.input_bounds1[1]);
5305f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    float v2 = r2_array(i, j);
5315f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    float v3 = r2_implicit_array2(i % spec.input_bounds2[0],
5325f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                                  j % spec.input_bounds2[1]);
5335f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    float tmp = ApplyOpToFloats(spec.op1, v1, v2);
5345f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower    *v = ApplyOpToFloats(spec.op2, tmp, v3);
5355f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  });
5365f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
53746737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto expected = Literal::CreateR2FromArray2D(expected_array);
5385f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower  ComputeAndCompareLiteral(
5395f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      &builder, *expected,
5405f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      {r2_implicit_global_data1.get(), r2_global_data.get(),
5415f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower       r2_implicit_global_data2.get()},
5425f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower      ErrorSpec(1e-6, 1e-6));
5435f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower}
5445f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
5455f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlowerINSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
5465f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                        BroadcastR2ImplicitTest,
5475f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower                        ::testing::ValuesIn(kR2ImplicitBroadcastTestCases));
5485f6cddc05ed8f90a72bc11230a81133d67624e7aA. Unique TensorFlower
549688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
550688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
55146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}}));
55246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
553688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  b.Add(r2, r1);
554688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
55546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto expected = Literal::CreateR2<float>({{2, 4}, {4, 6}});
556688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
557688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
558688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower}
559688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
560688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlowerXLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
561688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputationBuilder b(client_, TestName());
56246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1}, {2}}));
56346737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
564688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  b.Add(r2, r1);
565688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
56646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto expected = Literal::CreateR2<float>({{2, 3}, {5, 6}});
567688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower
568688f5a6c77bd97e116f55e130872ea0713f9cdb1A. Unique TensorFlower  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
5691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
571902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
572902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  ComputationBuilder b(client_, TestName());
573902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r1 = b.ConstantR1<float>({10, 20});
574902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r3 = b.ConstantLiteral(
57546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
576902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  b.Add(r3, r1, {0});
577902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
57846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto expected =
57946737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
580902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
581902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
582902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune}
583902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
584902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
585902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  ComputationBuilder b(client_, TestName());
586902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r1 = b.ConstantR1<float>({10, 20});
587902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r3 = b.ConstantLiteral(
58846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
589902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  b.Add(r1, r3, {1});
590902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
59146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto expected =
59246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
593902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
594902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
595902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune}
596902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
597902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
598902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  ComputationBuilder b(client_, TestName());
599902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r1 = b.ConstantR1<float>({10, 20});
600902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r3 = b.ConstantLiteral(
60146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
602902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  b.Add(r1, r3, {2});
603902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
60446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto expected =
60546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      Literal::CreateR3<float>({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
606902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
607902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
608902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune}
609902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
610902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
611902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  ComputationBuilder b(client_, TestName());
612902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r1_0 = b.ConstantR1<float>({1000, 2000});
613902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r1_1 = b.ConstantR1<float>({100, 200});
614902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r1_2 = b.ConstantR1<float>({10, 20});
615902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  auto r3 = b.ConstantLiteral(
61646737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
617902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  for (int i = 0; i < 3; ++i) {
618902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune    r3 = b.Add(r1_0, r3, {0});
619902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune    r3 = b.Add(r3, r1_1, {1});
620902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune    r3 = b.Add(r1_2, r3, {2});
621902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  }
622902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  r3 = b.Mul(r3, b.ConstantR0<float>(-2));
623902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
62446737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto expected = Literal::CreateR3<float>(
625902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune      {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
626902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune       {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
627902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
628902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
629902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune}
630902e55230ef3136387fd9b3dc05e28fd06db95f9Bjarke Hammersholt Roune
631ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt RouneXLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
632ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  ComputationBuilder b(client_, TestName());
633ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  auto r1_0 = b.ConstantR1<float>({1000, 2000});
634ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  auto r1_1 = b.ConstantR1<float>({100, 200});
635ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  auto r1_2 = b.ConstantR1<float>({10, 20});
636ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  auto r0 = b.ConstantR0<float>(3);
637ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  auto r3 = b.Broadcast(r0, {2, 2, 2});
638ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  for (int i = 0; i < 3; ++i) {
639ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune    r3 = b.Add(r1_0, r3, {0});
640ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune    r3 = b.Add(r3, r1_1, {1});
641ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune    r3 = b.Add(r1_2, r3, {2});
642ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  }
643ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  r3 = b.Mul(r3, b.ConstantR0<float>(-1));
644ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune
64546737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower  auto expected = Literal::CreateR3<float>(
646ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune      {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
647ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune       {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
648ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune
649ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
650ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune}
651ee6f27b647fd51b11f9795042c4f6941c77d1c86Bjarke Hammersholt Roune
6521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
6531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2])
6541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // results in a shape incompatible with the lhs [2, 3, 1].
6551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
6561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Add(b.ConstantR2<float>({{1.0, 5.0}, {1.0, 5.0}}),
65846737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower        b.ConstantLiteral(*Literal::CreateR3<float>(
6591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
6601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        /*broadcast_dimensions=*/{1, 2});
6611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result_status = Execute(&b, {});
6631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(result_status.ok());
6645c8acccfc9e90d694a8394f5522097bfe87379b2A. Unique TensorFlower  EXPECT_THAT(result_status.status().error_message(),
6659992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower              HasSubstr("broadcast dimension 0 mismatch"));
6661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
6691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
6701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
6711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
6731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
6741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result_status = Execute(&b, {});
6761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(result_status.ok());
6779992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower  EXPECT_THAT(result_status.status().error_message(),
6789992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower              HasSubstr("binary op BINOP_ADD with incompatible shapes"));
6791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsXLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
6821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
6831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder b(client_, TestName());
6841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
6861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
6871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto result_status = Execute(&b, {});
6891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  EXPECT_FALSE(result_status.ok());
6909992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower  EXPECT_THAT(result_status.status().error_message(),
6919992074410d0b8d7102b7a63ff5f01a1a4554357A. Unique TensorFlower              HasSubstr("binary op BINOP_ADD with incompatible shapes"));
6921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
6931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
6941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
6951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
696