1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
30
31namespace xla {
32namespace {
33
34class BroadcastTest : public HloTestBase {};
35
36XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
37  // Test degenerate case of broadcasting a scalar into a scalar.
38  auto builder = HloComputation::Builder(TestName());
39  auto input = builder.AddInstruction(
40      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
41  builder.AddInstruction(HloInstruction::CreateBroadcast(
42      ShapeUtil::MakeShape(F32, {}), input, {}));
43
44  // Create HLO module, compile, and execute.
45  auto hlo_module = CreateNewModule();
46  hlo_module->AddEntryComputation(builder.Build());
47  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
48
49  LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(42.0), *result,
50                              error_spec_);
51}
52
53XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
54  auto builder = HloComputation::Builder(TestName());
55  auto input = builder.AddInstruction(
56      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
57  builder.AddInstruction(HloInstruction::CreateBroadcast(
58      ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
59
60  // Create HLO module, compile, and execute.
61  auto hlo_module = CreateNewModule();
62  hlo_module->AddEntryComputation(builder.Build());
63  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
64
65  LiteralTestUtil::ExpectNear(
66      *Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
67      error_spec_);
68}
69
70XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
71  auto builder = HloComputation::Builder(TestName());
72  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
73      Literal::CreateR1<float>({1.0, 2.0, 3.0})));
74
75  // Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple
76  // to enable testing of the results.
77  auto element1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
78      ShapeUtil::MakeShape(F32, {3, 2}), input, {0}));
79  auto element2 = builder.AddInstruction(HloInstruction::CreateBroadcast(
80      ShapeUtil::MakeShape(F32, {2, 3}), input, {1}));
81  builder.AddInstruction(HloInstruction::CreateTuple({element1, element2}));
82
83  // Create HLO module, compile, and execute.
84  auto hlo_module = CreateNewModule();
85  hlo_module->AddEntryComputation(builder.Build());
86  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
87
88  LiteralTestUtil::ExpectNear(
89      *Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
90      LiteralView::Create(*result, {0}), error_spec_);
91
92  LiteralTestUtil::ExpectNear(
93      *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
94      LiteralView::Create(*result, {1}), error_spec_);
95}
96
97XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
98  auto builder = HloComputation::Builder(TestName());
99  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
100      Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
101  builder.AddInstruction(HloInstruction::CreateBroadcast(
102      ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
103
104  // Create HLO module, compile, and execute.
105  auto hlo_module = CreateNewModule();
106  hlo_module->AddEntryComputation(builder.Build());
107  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
108
109  LiteralTestUtil::ExpectNear(
110      *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
111      error_spec_);
112}
113
114XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
115  // Degenerately broadcasting a shape into a shape of the same rank reorders
116  // the dimensions, ie transpose.
117  auto builder = HloComputation::Builder(TestName());
118  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
119      Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
120  builder.AddInstruction(HloInstruction::CreateBroadcast(
121      ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
122
123  // Create HLO module, compile, and execute.
124  auto hlo_module = CreateNewModule();
125  hlo_module->AddEntryComputation(builder.Build());
126  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
127
128  LiteralTestUtil::ExpectNear(
129      *Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
130      error_spec_);
131}
132
133XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
134  auto builder = HloComputation::Builder(TestName());
135  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
136      Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
137  builder.AddInstruction(HloInstruction::CreateBroadcast(
138      ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
139
140  // Create HLO module, compile, and execute.
141  auto hlo_module = CreateNewModule();
142  hlo_module->AddEntryComputation(builder.Build());
143  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
144
145  LiteralTestUtil::ExpectNear(
146      *Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
147                                 {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
148      *result, error_spec_);
149}
150
151TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
152  auto builder = HloComputation::Builder(TestName());
153  auto input = builder.AddInstruction(
154      HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 2.0})));
155
156  // Broadcast vector in dimension 1.
157  builder.AddInstruction(HloInstruction::CreateBroadcast(
158      ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));
159
160  // Create HLO module, compile, and execute.
161  auto hlo_module = CreateNewModule();
162  hlo_module->AddEntryComputation(builder.Build());
163  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
164
165  Array4D<float> expected(2, 2, 3, 3);
166  Array2D<float> pz({{1, 2}, {1, 2}});
167  expected.FillWithPZ(pz);
168
169  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
170                              *result, error_spec_);
171}
172
173TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
174  auto builder = HloComputation::Builder(TestName());
175  std::vector<float> input_data(1025);
176  int64 r1_size = input_data.size();
177  std::iota(input_data.begin(), input_data.end(), 0.0f);
178  auto input = builder.AddInstruction(
179      HloInstruction::CreateConstant(Literal::CreateR1<float>(input_data)));
180
181  // Broadcast vector in dimension 3.
182  builder.AddInstruction(HloInstruction::CreateBroadcast(
183      ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));
184
185  // Create HLO module, compile, and execute.
186  auto hlo_module = CreateNewModule();
187  hlo_module->AddEntryComputation(builder.Build());
188  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
189
190  Array4D<float> expected(3, 3, 3, 1025);
191  Array2D<float> yx(3, r1_size);
192  for (int64 y = 0; y < 3; ++y) {
193    for (int64 x = 0; x < r1_size; ++x) {
194      yx(y, x) = input_data[x];
195    }
196  }
197  expected.FillWithYX(yx);
198
199  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
200                              *result, error_spec_);
201}
202
203XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
204  auto builder = HloComputation::Builder(TestName());
205  Array4D<float> r4_array(32, 64, 7, 7);
206  r4_array.Fill(42.0);
207  std::vector<float> r1_array(64, 42.0);
208
209  auto input = builder.AddInstruction(
210      HloInstruction::CreateConstant(Literal::CreateR1<float>(r1_array)));
211
212  // Broadcast vector in dimension 1.
213  builder.AddInstruction(HloInstruction::CreateBroadcast(
214      ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));
215
216  // Create HLO module, compile, and execute.
217  auto hlo_module = CreateNewModule();
218  hlo_module->AddEntryComputation(builder.Build());
219  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
220
221  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result,
222                              error_spec_);
223}
224
225TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
226  auto builder = HloComputation::Builder(TestName());
227  auto input = builder.AddInstruction(
228      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
229  builder.AddInstruction(HloInstruction::CreateBroadcast(
230      ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
231
232  // Create HLO module, compile, and execute.
233  auto hlo_module = CreateNewModule();
234  hlo_module->AddEntryComputation(builder.Build());
235  LOG(INFO) << hlo_module->ToString();
236  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
237
238  Array4D<float> expected(64, 64, 3, 3);
239  expected.Fill(1.0f);
240
241  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
242                              *result, error_spec_);
243}
244
245TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
246  auto builder = HloComputation::Builder(TestName());
247  Array2D<float> to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}});
248  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
249      Literal::CreateR2FromArray2D<float>(to_broadcast)));
250
251  // Broadcast vector in dimensions 2 and 3.
252  builder.AddInstruction(HloInstruction::CreateBroadcast(
253      ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));
254
255  // Create HLO module, compile, and execute.
256  auto hlo_module = CreateNewModule();
257  hlo_module->AddEntryComputation(builder.Build());
258  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
259
260  Array4D<float> expected(3, 3, 2, 2);
261  expected.FillWithYX(to_broadcast);
262
263  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
264                              *result, error_spec_);
265}
266
267TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
268  auto builder = HloComputation::Builder(TestName());
269  Array3D<float> input_vals(2, 3, 4);
270  input_vals.FillRandom(1.0);
271
272  Array4D<float> expected(2, 3, 4, 5);
273  for (int i = 0; i < 2; ++i) {
274    for (int j = 0; j < 3; ++j) {
275      for (int k = 0; k < 4; ++k) {
276        for (int m = 0; m < 5; ++m) {
277          expected(i, j, k, m) = input_vals(i, j, k);
278        }
279      }
280    }
281  }
282  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
283      Literal::CreateR3FromArray3D<float>(input_vals)));
284
285  // Broadcast vector in dimensions 2 and 3.
286  builder.AddInstruction(HloInstruction::CreateBroadcast(
287      ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));
288
289  // Create HLO module, compile, and execute.
290  auto hlo_module = CreateNewModule();
291  hlo_module->AddEntryComputation(builder.Build());
292  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
293
294  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
295                              *result, error_spec_);
296}
297
298}  // namespace
299}  // namespace xla
300