1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Tests the reduce-window XLA operation.
17
#include <algorithm>
#include <array>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
20
21#include "tensorflow/compiler/xla/array2d.h"
22#include "tensorflow/compiler/xla/array3d.h"
23#include "tensorflow/compiler/xla/array4d.h"
24#include "tensorflow/compiler/xla/client/computation_builder.h"
25#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
26#include "tensorflow/compiler/xla/client/local_client.h"
27#include "tensorflow/compiler/xla/client/padding.h"
28#include "tensorflow/compiler/xla/reference_util.h"
29#include "tensorflow/compiler/xla/shape_util.h"
30#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32#include "tensorflow/compiler/xla/tests/literal_test_util.h"
33#include "tensorflow/compiler/xla/tests/test_macros.h"
34#include "tensorflow/compiler/xla/xla_data.pb.h"
35#include "tensorflow/core/lib/core/status.h"
36#include "tensorflow/core/lib/core/status_test_util.h"
37#include "tensorflow/core/lib/gtl/array_slice.h"
38#include "tensorflow/core/platform/test.h"
39#include "tensorflow/core/platform/types.h"
40
41namespace xla {
42namespace {
43
// Values of the bool test parameter that ReduceWindowTest is instantiated
// with: `false` runs a case in F32, `true` runs it in BF16. BF16 is only
// exercised when the backend defines XLA_BACKEND_SUPPORTS_BFLOAT16.
// Declared const: this is shared, read-only configuration (namespace-scope
// const already has internal linkage, so `static` is unnecessary).
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
// Tests both F32 and BF16.
const std::array<bool, 2> use_bfloat16_params = {false, true};
#else
// Only tests F32.
const std::array<bool, 1> use_bfloat16_params = {false};
#endif
51
52class ReduceWindowTestBase : public ClientLibraryTestBase {
53 public:
54  ErrorSpec DefaultErrorSpec() const {
55    if (use_bfloat16()) {
56      return ErrorSpec(1e-1, 5e-2);
57    } else {
58      return ErrorSpec(1e-3, 1e-3);
59    }
60  }
61};
62
63class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
64                         public ReduceWindowTestBase {
65 public:
66  ReduceWindowTest() : builder_(client_, TestName()) {
67    set_use_bfloat16(GetParam());
68  }
69
70  void ReduceWindowAdd(const ComputationDataHandle& input,
71                       tensorflow::gtl::ArraySlice<int64> window_dimensions,
72                       tensorflow::gtl::ArraySlice<int64> window_strides,
73                       Padding padding) {
74    auto init =
75        CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_);
76    builder_.ReduceWindow(input, init,
77                          CreateScalarAddComputation(FloatType(), &builder_),
78                          window_dimensions, window_strides, padding);
79  }
80
81  void ReduceWindowMax(const ComputationDataHandle& input,
82                       tensorflow::gtl::ArraySlice<int64> window_dimensions,
83                       tensorflow::gtl::ArraySlice<int64> window_strides,
84                       Padding padding) {
85    auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_);
86    builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions,
87                          window_strides, padding);
88  }
89
90  void ReduceWindowMin(const ComputationDataHandle& input,
91                       tensorflow::gtl::ArraySlice<int64> window_dimensions,
92                       tensorflow::gtl::ArraySlice<int64> window_strides,
93                       Padding padding) {
94    auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_);
95    builder_.ReduceWindow(input, init,
96                          CreateScalarMinComputation(FloatType(), &builder_),
97                          window_dimensions, window_strides, padding);
98  }
99
100  ComputationBuilder builder_;
101};
102
103TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
104  const auto input = CreateConstantFromLiteral(
105      *Literal::CreateR1<float>({1, 1, 1, 1}), &builder_);
106  const auto init_value =
107      CreateConstantFromLiteral(*Literal::CreateR0<float>(0), &builder_);
108  TF_ASSERT_OK(builder_.first_error());
109  builder_.ReduceWindow(input, init_value,
110                        CreateScalarAddComputation(FloatType(), &builder_),
111                        /*window_dimensions=*/{1, 2},
112                        /*window_strides=*/{1}, Padding::kValid);
113  ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT)
114      << builder_.first_error();
115  ASSERT_THAT(builder_.first_error().error_message(),
116              ::testing::HasSubstr("Want input dimensions size"));
117}
118
119// Regression test for b/68964348.
120TEST_P(ReduceWindowTest, R0ReduceWindow) {
121  const auto input =
122      CreateConstantFromLiteral(*Literal::CreateR0<float>(42.0), &builder_);
123  const auto init =
124      CreateConstantFromLiteral(*Literal::CreateR0<float>(1.0), &builder_);
125  builder_.ReduceWindow(input, init,
126                        CreateScalarAddComputation(FloatType(), &builder_),
127                        /*window_dimensions=*/{},
128                        /*window_strides=*/{}, Padding::kSame);
129  ComputeAndCompareLiteral(&builder_, *Literal::CreateR0<float>(43.0), {},
130                           ErrorSpec(0.00001));
131}
132
133TEST_P(ReduceWindowTest, Min3In5Stride2) {
134  const auto input = CreateConstantFromLiteral(
135      *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
136  ReduceWindowMin(input, {3}, {2}, Padding::kValid);
137  ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({100, 1}), {},
138                           ErrorSpec(0.00001));
139}
140
141TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
142  const auto input = CreateConstantFromLiteral(
143      *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
144  ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
145                  Padding::kSame);
146  ComputeAndCompareLiteral(&builder_,
147                           *Literal::CreateR1<float>({1000, 100, 10, 1, 1}), {},
148                           ErrorSpec(0.00001));
149}
150
151XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
152  Array4D<float> input_array(1, 0, 2, 1);
153  const auto input = CreateConstantFromArray(input_array, &builder_);
154  Padding padding = Padding::kSame;
155  ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
156
157  auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
158                                              {1, 1, 1, 1}, padding);
159
160  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
161                           DefaultErrorSpec());
162}
163
164TEST_P(ReduceWindowTest, NonSquareSmall) {
165  Array4D<float> input_array(1, 2, 2, 1);
166  input_array.FillRandom(2.f, 2.f);
167  const auto input = CreateConstantFromArray(input_array, &builder_);
168
169  Padding padding = Padding::kSame;
170  ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
171
172  auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
173                                              {1, 1, 1, 1}, padding);
174
175  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
176                           DefaultErrorSpec());
177}
178
179TEST_P(ReduceWindowTest, MiddleDimsSmall) {
180  Array4D<float> input_array(1, 3, 3, 1);
181  input_array.FillRandom(2.f, 2.f);
182  const auto input = CreateConstantFromArray(input_array, &builder_);
183  Padding padding = Padding::kSame;
184  ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding);
185
186  auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
187                                              {1, 2, 2, 1}, padding);
188
189  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
190                           DefaultErrorSpec());
191}
192
193TEST_P(ReduceWindowTest, Along2ndMinorDim) {
194  Array4D<float> input_array(3, 6, 7, 32);
195  input_array.FillRandom(2.f, 2.f);
196  const auto input = CreateConstantFromArray(input_array, &builder_);
197
198  // The parameters of this reduction mimic feature norm (e.g. LRN).
199  int lrn_diameter = 7;  // diameter = 2*radius + 1 --> must be odd
200  Padding padding = Padding::kSame;
201  ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
202
203  auto res = ReferenceUtil::ReduceWindow4DAdd(
204      input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
205
206  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
207                           DefaultErrorSpec());
208}
209
210TEST_P(ReduceWindowTest, AmongMajor2Dims) {
211  Array4D<float> input_array(4, 4, 6, 8);
212  input_array.FillWithMinorDimNum();
213  const auto input_data_handle =
214      CreateConstantFromArray(input_array, &builder_);
215
216  int win_len = 3;
217  int win_stride = 1;
218
219  Padding padding = Padding::kSame;
220  // Reduce only along the x and y dimensions, according to the win_len.
221  ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
222                  {win_stride, win_stride, 1, 1}, padding);
223
224  auto result = ReferenceUtil::ReduceWindow4DAdd(
225      input_array, 0.0f, {win_len, win_len, 1, 1},
226      {win_stride, win_stride, 1, 1}, padding);
227
228  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
229                           DefaultErrorSpec());
230}
231
232TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
233  Array4D<float> input_array(9, 12, 4, 89);
234  input_array.FillRandom(2.f, 2.f);
235
236  int win_len = 3;
237  int win_stride = 2;
238
239  const auto input_data_handle =
240      CreateConstantFromArray(input_array, &builder_);
241
242  Padding padding = Padding::kSame;
243  // Reduce only along the x and y dimensions, according to the win_len.
244  ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
245                  {win_stride, win_stride, 1, 1}, padding);
246
247  auto result = ReferenceUtil::ReduceWindow4DAdd(
248      input_array, 0.0f, {win_len, win_len, 1, 1},
249      {win_stride, win_stride, 1, 1}, padding);
250
251  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
252                           DefaultErrorSpec());
253}
254
255// Tests a reduction function that is not a simple add/min/max/etc.
256XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
257  Array4D<float> input_array(1, 2, 2, 1);
258  input_array(0, 0, 0, 0) = 1;
259  input_array(0, 0, 1, 0) = 2;
260  input_array(0, 1, 0, 0) = 3;
261  input_array(0, 1, 1, 0) = 4;
262  const auto input = CreateConstantFromArray(input_array, &builder_);
263
264  Padding padding = Padding::kValid;
265  const Shape scalar = ShapeUtil::MakeShape(FloatType(), {});
266  auto b = builder_.CreateSubBuilder("unusual");
267  auto lhs = b->Parameter(0, scalar, "lhs");
268  auto rhs = b->Parameter(1, scalar, "rhs");
269  b->Min(b->Add(lhs, rhs),
270         CreateConstantFromLiteral(*Literal::CreateR0<float>(8.0f), b.get()));
271  Computation reduce_fn = b->BuildAndNoteError();
272
273  builder_.ReduceWindow(
274      input,
275      CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_),
276      reduce_fn,
277      /*window_dimensions=*/{1, 1, 2, 1},
278      /*window_strides=*/{1, 1, 1, 1}, padding);
279
280  const auto reduce_func = [](float arg1, float arg2) {
281    return std::min<float>(arg1 + arg2, 8.0f);
282  };
283
284  auto expected =
285      ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func,
286                                           /*window=*/{1, 1, 2, 1},
287                                           /*stride=*/{1, 1, 1, 1}, padding);
288
289  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {},
290                           DefaultErrorSpec());
291}
292
293TEST_P(ReduceWindowTest, R4UnitWindow) {
294  Array4D<float> input_array(13, 12, 8, 15);
295  input_array.FillRandom(2.f, 2.f);
296  std::unique_ptr<Literal> input_literal =
297      Literal::CreateR4FromArray4DWithLayout(
298          input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
299  ComputationDataHandle input;
300  auto input_data = CreateParameterAndTransferLiteral(
301      0, *input_literal, "parameter", &builder_, &input);
302
303  Padding padding = Padding::kSame;
304  ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding);
305
306  auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
307                                              {1, 4, 1, 1}, padding);
308
309  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
310                           {input_data.get()}, DefaultErrorSpec());
311}
312
313XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
314  std::vector<int64> input_dims(6, 8);
315  auto shape = ShapeUtil::MakeShape(F32, input_dims);
316
317  std::unique_ptr<Literal> arg_literal = Literal::CreateFromShape(shape);
318  auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> float {
319    return 1.0f;
320  };
321  TF_EXPECT_OK(arg_literal->Populate<float>(generator));
322
323  const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
324
325  Padding padding = Padding::kValid;
326  ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
327
328  std::vector<int64> output_layout = {1, 5, 3, 2, 0, 4};
329  std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
330  Shape result_shape =
331      ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
332  std::unique_ptr<Literal> expected = Literal::CreateFromShape(result_shape);
333  auto out_generator =
334      [&](tensorflow::gtl::ArraySlice<int64> indexes) -> float {
335    return 27.0f;
336  };
337  TF_EXPECT_OK(expected->Populate<float>(out_generator));
338
339  ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
340}
341
342XLA_TEST_P(ReduceWindowTest, R6Add) {
343  std::vector<int64> input_dims(6, 8);
344  auto shape = ShapeUtil::MakeShape(F32, input_dims);
345
346  std::unique_ptr<Literal> arg_literal =
347      Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
348
349  const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
350
351  Padding padding = Padding::kValid;
352  ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
353
354  std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
355  std::unique_ptr<Literal> expected =
356      Literal::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
357
358  ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
359}
360
361XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
362  Array4D<float> input_array(2, 1, 27, 119);
363  input_array.FillRandom(2.0f);
364  std::unique_ptr<Literal> input_literal =
365      Literal::CreateR4FromArray4DWithLayout(
366          input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
367  ComputationDataHandle input;
368  auto input_data = CreateParameterAndTransferLiteral(
369      0, *input_literal, "parameter", &builder_, &input);
370
371  int win_len = 1;
372  int stride = 8;
373  Padding padding = Padding::kSame;
374  ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
375
376  auto res = ReferenceUtil::ReduceWindow4DAdd(
377      input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
378
379  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
380                           {input_data.get()}, DefaultErrorSpec());
381}
382
383XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
384  Array4D<float> input_array(3, 2, 4, 64);
385  input_array.FillRandom(2.0f);
386  std::unique_ptr<Literal> input_literal =
387      Literal::CreateR4FromArray4DWithLayout(
388          input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
389  ComputationDataHandle input;
390  auto input_data = CreateParameterAndTransferLiteral(
391      0, *input_literal, "parameter", &builder_, &input);
392
393  int win_len = 3;
394  int stride = 1;
395  Padding padding = Padding::kSame;
396  ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
397
398  auto res = ReferenceUtil::ReduceWindow4DAdd(
399      input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
400
401  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
402                           {input_data.get()}, DefaultErrorSpec());
403}
404
405XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
406  Array4D<float> input_array(1, 3, 12, 200);
407  input_array.FillRandom(2.0f);
408  std::unique_ptr<Literal> input_literal =
409      Literal::CreateR4FromArray4DWithLayout(
410          input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
411  ComputationDataHandle input;
412  auto input_data = CreateParameterAndTransferLiteral(
413      0, *input_literal, "parameter", &builder_, &input);
414
415  int win_len = 8;
416  int stride = 5;
417  Padding padding = Padding::kSame;
418  ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
419
420  auto res = ReferenceUtil::ReduceWindow4DAdd(
421      input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
422
423  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
424                           {input_data.get()}, DefaultErrorSpec());
425}
426
427TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
428  Array4D<float> input_array(6, 4, 10, 130);
429  input_array.FillRandom(2.0f);
430
431  int win_len = 3;
432  int win_stride = 2;
433
434  Padding padding = Padding::kSame;
435  const auto input_data_handle =
436      CreateConstantFromArray(input_array, &builder_);
437  // Reduce only along the x and y dimensions, according to the win_len.
438  ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
439                  {win_stride, win_stride, 1, 1}, padding);
440
441  auto result = ReferenceUtil::ReduceWindow4DAdd(
442      input_array, 0.0f, {win_len, win_len, 1, 1},
443      {win_stride, win_stride, 1, 1}, padding);
444  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
445                           DefaultErrorSpec());
446}
447
448XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
449  std::vector<float> input_vector(128 * 9, 1);
450  const auto input = CreateConstantFromLiteral(
451      *Literal::CreateR1<float>(input_vector), &builder_);
452  ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
453  ComputeAndCompareLiteral(
454      &builder_,
455      *Literal::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
456      DefaultErrorSpec());
457}
458
459XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
460  std::vector<float> input_vector{
461      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
462      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
463      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
464      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
465      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
466      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
467      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
468      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
469  const auto input = CreateConstantFromLiteral(
470      *Literal::CreateR1<float>(input_vector), &builder_);
471  ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
472  ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {},
473                           DefaultErrorSpec());
474}
475
476XLA_TEST_P(ReduceWindowTest, Add128In128) {
477  std::vector<float> input_vector{
478      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
479      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
480      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
481      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
482      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
483      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
484      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
485      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
486  const auto input = CreateConstantFromLiteral(
487      *Literal::CreateR1<float>(input_vector), &builder_);
488  ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
489  ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {},
490                           DefaultErrorSpec());
491}
492
493// Regression test for a bug that appeared in Inception (b/34784899).
494TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
495  Array2D<float> input_array(14, 14, 1.0f);
496  const auto input = CreateConstantFromArray(input_array, &builder_);
497
498  int win_len = 3;
499  int stride = 1;
500  Padding padding = Padding::kSame;
501  ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding);
502
503  auto res = ReferenceUtil::ReduceWindow2DAdd(
504      input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
505
506  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res),
507                           {}, DefaultErrorSpec());
508}
509
510TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
511  Array2D<float> input_array(6, 4, 1.0f);
512  ComputationDataHandle input = builder_.Broadcast(
513      CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4});
514
515  Padding padding = Padding::kSame;
516  ReduceWindowAdd(input, {4, 2}, {3, 3}, padding);
517
518  auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
519                                              padding);
520
521  ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res),
522                           {}, DefaultErrorSpec());
523}
524
525INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
526                        ::testing::ValuesIn(use_bfloat16_params));
527
// Reduction performed by an R4ReduceWindowTest case.
enum Reducer {
  kAdd,  // Sum of the elements in each window.
  kMax,  // Maximum of the elements in each window.
};
529
530struct R4ReduceWindowTestData {
531  int64 base_bounds[4];
532  int64 window_bounds[4];
533  int64 strides[4];
534  int64 pad_low[4];
535  int64 pad_high[4];
536  int64 layout[4];
537
538  Reducer reducer;
539};
540
541string R4ReduceWindowTestDataToString(
542    const ::testing::TestParamInfo<
543        ::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
544  const auto& param = ::testing::get<0>(data.param);
545  string str = tensorflow::strings::StrCat(
546      "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),  //
547      "__window_bounds_",
548      tensorflow::str_util::Join(param.window_bounds, "x"),            //
549      "__strides_", tensorflow::str_util::Join(param.strides, "x"),    //
550      "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"),    //
551      "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"),  //
552      "__layout_", tensorflow::str_util::Join(param.layout, "_"),      //
553      (param.reducer == kAdd) ? "_add" : "_max");
554  CHECK(param.reducer == kAdd || param.reducer == kMax);
555
556  // Test names are not allowed to contain the '-' character.
557  std::replace(str.begin(), str.end(), '-', 'n');
558  if (::testing::get<1>(data.param)) {
559    str = tensorflow::strings::StrCat(str, "_bfloat16");
560  }
561  return str;
562}
563
564class R4ReduceWindowTest : public ReduceWindowTestBase,
565                           public ::testing::WithParamInterface<
566                               ::testing::tuple<R4ReduceWindowTestData, bool>> {
567 protected:
568  R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
569
570  void DoIt() {
571    ComputationBuilder b(client_, TestName());
572    const auto& param = ::testing::get<0>(GetParam());
573
574    const float kInitValue = 0.0f;
575
576    Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
577                         param.base_bounds[2], param.base_bounds[3]);
578    input.FillIota(1);
579    std::unique_ptr<Literal> input_literal =
580        Literal::CreateR4FromArray4DWithLayout(
581            input, LayoutUtil::MakeLayout(param.layout));
582    ComputationDataHandle parameter;
583    auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
584                                                       &b, &parameter);
585
586    std::vector<std::pair<int64, int64>> padding(4);
587    for (int i = 0; i < 4; ++i) {
588      padding[i] = {param.pad_low[i], param.pad_high[i]};
589    }
590
591    auto init_value =
592        CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
593    CHECK(param.reducer == kAdd || param.reducer == kMax);
594    auto computation = param.reducer == kAdd
595                           ? CreateScalarAddComputation(FloatType(), &b)
596                           : CreateScalarMaxComputation(FloatType(), &b);
597    b.ReduceWindowWithGeneralPadding(
598        /*operand=*/parameter,
599        /*init_value=*/init_value,
600        /*computation=*/computation,
601        /*window_dimensions=*/param.window_bounds,
602        /*window_strides=*/param.strides,
603        /*padding=*/padding);
604
605    CHECK(param.reducer == kAdd || param.reducer == kMax);
606    auto reduce_func = param.reducer == kAdd
607                           ? +[](float a, float b) { return a + b; }
608                           : +[](float a, float b) { return std::max(a, b); };
609    std::unique_ptr<Array4D<float>> expected =
610        ReferenceUtil::ReduceWindow4DGeneric(
611            /*operand=*/input,
612            /*init=*/kInitValue,
613            /*reduce_func=*/reduce_func,
614            /*window=*/param.window_bounds,
615            /*stride=*/param.strides,
616            /*padding=*/padding);
617    std::unique_ptr<Literal> expected_literal =
618        Literal::CreateFromArray(*expected);
619    const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
620        input_literal->shape().element_type(),
621        AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
622    ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()},
623                             DefaultErrorSpec(), &expected_shape_with_layout);
624  }
625};
626
627TEST_P(R4ReduceWindowTest, DoIt) { DoIt(); }
628
// Each entry lists: base_bounds, window_bounds, strides, pad_low, pad_high,
// layout, reducer.
630const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
631    // Minimal edge case.
632    R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1},
633                           /*window_bounds=*/{1, 1, 1, 1},
634                           /*strides=*/{1, 1, 1, 1},
635                           /*pad_low=*/{0, 0, 0, 0},
636                           /*pad_high=*/{0, 0, 0, 0},
637                           /*layout=*/{3, 2, 1, 0},
638                           /*reducer=*/kAdd},
639
640    // Arbitrary padding (not kSame or kValid).
641    R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89},
642                           /*window_bounds=*/{3, 3, 1, 1},
643                           /*strides=*/{2, 2, 1, 1},
644                           /*pad_low=*/{4, 4, 0, 0},
645                           /*pad_high=*/{4, 4, 0, 0},
646                           /*layout=*/{3, 2, 1, 0},
647                           /*reducer=*/kAdd},
648
649    // Zero base bound edge case.
650    R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1},
651                           /*window_bounds=*/{1, 1, 1, 1},
652                           /*strides=*/{1, 1, 1, 1},
653                           /*pad_low=*/{0, 0, 0, 0},
654                           /*pad_high=*/{0, 0, 0, 0},
655                           /*layout=*/{3, 2, 1, 0},
656                           /*reducer=*/kAdd},
657
658    // With non-1x1 window.
659    R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
660                           /*window_bounds=*/{2, 3, 1, 1},
661                           /*strides=*/{1, 1, 1, 1},
662                           /*pad_low=*/{0, 0, 0, 0},
663                           /*pad_high=*/{0, 0, 0, 0},
664                           /*layout=*/{3, 2, 1, 0},
665                           /*reducer=*/kAdd},
666
667    // With max instead of add.
668    R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
669                           /*window_bounds=*/{2, 3, 1, 1},
670                           /*strides=*/{1, 1, 1, 1},
671                           /*pad_low=*/{0, 0, 0, 0},
672                           /*pad_high=*/{0, 0, 0, 0},
673                           /*layout=*/{3, 2, 1, 0},
674                           /*reducer=*/kMax},
675
676    // With stride.
677    R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140},
678                           /*window_bounds=*/{3, 2, 1, 1},
679                           /*strides=*/{2, 4, 1, 1},
680                           /*pad_low=*/{0, 0, 0, 0},
681                           /*pad_high=*/{0, 0, 0, 0},
682                           /*layout=*/{3, 2, 1, 0},
683                           /*reducer=*/kAdd},
684
685    // With low padding.
686    R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
687                           /*window_bounds=*/{3, 2, 1, 1},
688                           /*strides=*/{2, 2, 1, 1},
689                           /*pad_low=*/{3, 2, 0, 0},
690                           /*pad_high=*/{0, 0, 0, 0},
691                           /*layout=*/{3, 2, 1, 0},
692                           /*reducer=*/kAdd},
693
694    // With high padding.
695    R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
696                           /*window_bounds=*/{3, 2, 1, 1},
697                           /*strides=*/{2, 2, 1, 1},
698                           /*pad_low=*/{0, 0, 0, 0},
699                           /*pad_high=*/{2, 3, 0, 0},
700                           /*layout=*/{3, 2, 1, 0},
701                           /*reducer=*/kAdd},
702
703    // Window touches both sides of the padding simultaneously.
704    R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
705                           /*window_bounds=*/{3, 3, 1, 1},
706                           /*strides=*/{1, 1, 1, 1},
707                           /*pad_low=*/{1, 1, 0, 0},
708                           /*pad_high=*/{1, 1, 0, 0},
709                           /*layout=*/{3, 2, 1, 0},
710                           /*reducer=*/kAdd},
711
712    // Window is entirely in the padding for some positions.
713    R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
714                           /*window_bounds=*/{3, 3, 1, 1},
715                           /*strides=*/{1, 1, 1, 1},
716                           /*pad_low=*/{4, 4, 0, 0},
717                           /*pad_high=*/{4, 4, 0, 0},
718                           /*layout=*/{3, 2, 1, 0},
719                           /*reducer=*/kAdd},
720
721    // Zero base bound with padding edge case.
722    R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4},
723                           /*window_bounds=*/{1, 1, 1, 1},
724                           /*strides=*/{1, 1, 1, 1},
725                           /*pad_low=*/{0, 1, 0, 0},
726                           /*pad_high=*/{0, 0, 0, 0},
727                           /*layout=*/{3, 2, 1, 0},
728                           /*reducer=*/kAdd},
729
730    // With stride, low padding and high padding.
731    R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140},
732                           /*window_bounds=*/{3, 4, 1, 1},
733                           /*strides=*/{3, 1, 1, 1},
734                           /*pad_low=*/{10, 1, 0, 0},
735                           /*pad_high=*/{2, 3, 0, 0},
736                           /*layout=*/{3, 2, 1, 0},
737                           /*reducer=*/kAdd},
738
739    // With second minor dimension == 9.
740    R4ReduceWindowTestData{/*base_bounds=*/{2, 3, 9, 127},
741                           /*window_bounds=*/{1, 1, 1, 1},
742                           /*strides=*/{1, 1, 1, 1},
743                           /*pad_low=*/{0, 0, 0, 0},
744                           /*pad_high=*/{0, 0, 0, 0},
745                           /*layout=*/{3, 2, 1, 0},
746                           /*reducer=*/kAdd},
747
748    // With minor dimension == 129.
749    R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
750                           /*window_bounds=*/{1, 1, 1, 1},
751                           /*strides=*/{1, 1, 1, 1},
752                           /*pad_low=*/{0, 0, 0, 0},
753                           /*pad_high=*/{0, 0, 0, 0},
754                           /*layout=*/{3, 2, 1, 0},
755                           /*reducer=*/kAdd},
756
757    // With minor dims reduction and non-overlapped stride.
758    R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
759                           /*window_bounds=*/{1, 1, 2, 2},
760                           /*strides=*/{1, 1, 2, 2},
761                           /*pad_low=*/{0, 0, 0, 0},
762                           /*pad_high=*/{0, 0, 0, 0},
763                           /*layout=*/{3, 2, 1, 0},
764                           /*reducer=*/kAdd},
765
766    // With minor dims reduction and overlapped stride.
767    R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
768                           /*window_bounds=*/{1, 1, 4, 4},
769                           /*strides=*/{1, 1, 2, 2},
770                           /*pad_low=*/{0, 0, 0, 0},
771                           /*pad_high=*/{1, 0, 0, 0},
772                           /*layout=*/{3, 2, 1, 0},
773                           /*reducer=*/kAdd},
774};
775
// Instantiates R4ReduceWindowTest for every entry of kR4ReduceWindowTestValues
// crossed with every value of use_bfloat16_params (F32, plus BF16 when
// XLA_BACKEND_SUPPORTS_BFLOAT16 is defined). Test names are produced by
// R4ReduceWindowTestDataToString.
INSTANTIATE_TEST_CASE_P(
    R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
    ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R4ReduceWindowTestDataToString);
781
// Reuses the R4ReduceWindowTest fixture (and its DoIt) in a separate suite so
// that the cases below can be disabled on the interpreter backend.
class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};

XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); }

// Test cases that are large, slow, or failing.
const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
    // 3x3x1x5 max window over a 28x28x256x128 input, stride 5 on the minor
    // dimension.
    R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128},
                           /*window_bounds=*/{3, 3, 1, 5},
                           /*strides=*/{1, 1, 1, 5},
                           /*pad_low=*/{1, 1, 0, 0},
                           /*pad_high=*/{1, 1, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kMax},

    // Stride-2 3x3 sum over the two major dimensions of a large input.
    R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128},
                           /*window_bounds=*/{3, 3, 1, 1},
                           /*strides=*/{2, 2, 1, 1},
                           /*pad_low=*/{0, 0, 0, 0},
                           /*pad_high=*/{1, 1, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kAdd},

    // Long dimension 2 (32765 elements) reduced in non-overlapping windows of
    // 4, with 1 low and 2 high padding elements.
    R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2},
                           /*window_bounds=*/{1, 1, 4, 1},
                           /*strides=*/{1, 1, 4, 1},
                           /*pad_low=*/{0, 0, 1, 0},
                           /*pad_high=*/{0, 0, 2, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kMax},
};

// Crosses the large cases with every value of use_bfloat16_params.
INSTANTIATE_TEST_CASE_P(
    R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
    ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R4ReduceWindowTestDataToString);
818
// Reuses the R4ReduceWindowTest fixture (and its DoIt) for cases whose windows
// may cover any combination of the four dimensions.
class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {};

// TODO(b/72234705): Fix the test cases that fail on CPU and GPU.
XLA_TEST_P(R4ReduceWindowAnyDimsTest,
           DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
  DoIt();
}

const R4ReduceWindowTestData kR4ReduceWindowAnyDimsTestValues[] = {
    // Window spans all four dimensions.
    R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
                           /*window_bounds=*/{2, 3, 4, 5},
                           /*strides=*/{1, 1, 1, 1},
                           /*pad_low=*/{0, 0, 0, 0},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kAdd},
    // Max over the two major dimensions only.
    R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
                           /*window_bounds=*/{2, 3, 1, 1},
                           /*strides=*/{1, 1, 1, 1},
                           /*pad_low=*/{0, 0, 0, 0},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kMax},
    // With 0321 layout.
    R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
                           /*window_bounds=*/{2, 3, 4, 5},
                           /*strides=*/{1, 2, 3, 4},
                           /*pad_low=*/{0, 0, 0, 0},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{0, 3, 2, 1},
                           /*reducer=*/kAdd},

    // With 0123 layout.
    R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 23},
                           /*window_bounds=*/{2, 3, 7, 9},
                           /*strides=*/{1, 2, 5, 8},
                           /*pad_low=*/{0, 0, 0, 0},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{0, 1, 2, 3},
                           /*reducer=*/kAdd},
};

// Crosses the any-dims cases with every value of use_bfloat16_params.
INSTANTIATE_TEST_CASE_P(
    R4ReduceWindowAnyDimsTestInstantiation, R4ReduceWindowAnyDimsTest,
    ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowAnyDimsTestValues),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R4ReduceWindowTestDataToString);
866
// One rank-3 reduce-window case consumed by R3ReduceWindowTest.
struct R3ReduceWindowTestData {
  // Dimensions of the (all-ones) input array.
  int64 base_bounds[3];
  // Window size per dimension (window_dimensions of ReduceWindow).
  int64 window_bounds[3];
  // Window stride per dimension (window_strides of ReduceWindow).
  int64 strides[3];
  // Passed to LayoutUtil::MakeLayout for the input literal.
  int64 layout[3];
  // Padding scheme handed to ReduceWindow and the reference implementation.
  Padding padding;
  // The R3 test body CHECKs that this is kAdd.
  Reducer reducer;
} kR3TestCases[] = {
    {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2},
     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
     /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
     /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
     /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1},
     /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    // Same shape as above with two non-default layouts.
    {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
     /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
     /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
};
897
898string R3ReduceWindowTestDataToString(
899    const ::testing::TestParamInfo<
900        ::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
901  const auto& param = ::testing::get<0>(data.param);
902  string str = tensorflow::strings::StrCat(
903      "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),
904      "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"),
905      "__strides_", tensorflow::str_util::Join(param.strides, "x"),
906      "__padding_", param.padding == Padding::kSame ? "same" : "valid",
907      "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2],
908      "__reducer_", param.reducer == kAdd ? "add" : "max");
909  if (::testing::get<1>(data.param)) {
910    str = tensorflow::strings::StrCat(str, "_bfloat16");
911  }
912  return str;
913}
914
915class R3ReduceWindowTest : public ReduceWindowTestBase,
916                           public ::testing::WithParamInterface<
917                               ::testing::tuple<R3ReduceWindowTestData, bool>> {
918 protected:
919  R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
920};
921
922TEST_P(R3ReduceWindowTest, Add) {
923  ComputationBuilder b(client_, TestName());
924  const auto& param = ::testing::get<0>(GetParam());
925  CHECK(param.reducer == kAdd);
926
927  const float kInitValue = 0.0f;
928  Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
929                       param.base_bounds[2], 1.0f);
930  std::unique_ptr<Literal> input_literal =
931      Literal::CreateR3FromArray3DWithLayout(
932          input, LayoutUtil::MakeLayout(param.layout));
933
934  ComputationDataHandle parameter;
935  auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
936                                                     &b, &parameter);
937  auto init_value =
938      CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
939  b.ReduceWindow(/*operand=*/parameter,
940                 /*init_value=*/init_value,
941                 /*computation=*/CreateScalarAddComputation(FloatType(), &b),
942                 /*window_dimensions=*/param.window_bounds,
943                 /*window_strides=*/param.strides, /*padding=*/param.padding);
944
945  auto expected = ReferenceUtil::ReduceWindow3DAdd(
946      /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
947      /*stride=*/param.strides, /*padding=*/param.padding);
948
949  ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
950                           {input_arg.get()}, DefaultErrorSpec());
951}
952
// Crosses every kR3TestCases entry with every value of use_bfloat16_params.
// Names come from R3ReduceWindowTestDataToString.
INSTANTIATE_TEST_CASE_P(
    R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
    ::testing::Combine(::testing::ValuesIn(kR3TestCases),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R3ReduceWindowTestDataToString);
958
// One rank-2 reduce-window case consumed by R2ReduceWindowTest::DoIt.
struct R2ReduceWindowTestData {
  // Dimensions of the (all-ones) input array.
  int64 base_bounds[2];
  // Window size per dimension (window_dimensions of ReduceWindow).
  int64 window_bounds[2];
  // Window stride per dimension (window_strides of ReduceWindow).
  int64 strides[2];
  // Passed to LayoutUtil::MakeLayout for the input literal.
  int64 layout[2];
  // Padding scheme handed to ReduceWindow and the reference implementation.
  Padding padding;
  // R2ReduceWindowTest::DoIt CHECKs that this is kAdd.
  Reducer reducer;
} kR2TestCases[] = {
    {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
     /*strides=*/{1, 2}, /*layout=*/{0, 1},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4},
     /*strides=*/{1, 1}, /*layout=*/{0, 1},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3},
     /*strides=*/{1, 1}, /*layout=*/{0, 1},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100},
     /*strides=*/{2, 99}, /*layout=*/{0, 1},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25},
     /*strides=*/{5, 4}, /*layout=*/{0, 1},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2},
     /*strides=*/{3, 3}, /*layout=*/{0, 1},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36},
     /*strides=*/{4, 5}, /*layout=*/{1, 0},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93},
     /*strides=*/{1, 1}, /*layout=*/{1, 0},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    // Regression test for a bug that appeared in Inception (b/34784899).
    {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3},
     /*strides=*/{1, 1}, /*layout=*/{1, 0},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    // Regression test for a bug that appeared in Inception (b/34784899).
    // NOTE(review): this comment duplicates the one above; confirm that both
    // cases really originate from b/34784899.
    {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2},
     /*strides=*/{2, 2}, /*layout=*/{1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
     /*strides=*/{1, 1}, /*layout=*/{1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
};
1003
1004string R2ReduceWindowTestDataToString(
1005    const ::testing::TestParamInfo<
1006        ::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
1007  const auto& param = ::testing::get<0>(data.param);
1008  string str = tensorflow::strings::StrCat(
1009      "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),  //
1010      "__window_bounds_",
1011      tensorflow::str_util::Join(param.window_bounds, "x"),              //
1012      "__strides_", tensorflow::str_util::Join(param.strides, "x"),      //
1013      "__padding_", param.padding == Padding::kSame ? "same" : "valid",  //
1014      "__layout_", param.layout[0], "_", param.layout[1],                //
1015      "__reducer_", param.reducer == kAdd ? "add" : "max");
1016  if (::testing::get<1>(data.param)) {
1017    str = tensorflow::strings::StrCat(str, "_bfloat16");
1018  }
1019  return str;
1020}
1021
1022class R2ReduceWindowTest : public ReduceWindowTestBase,
1023                           public ::testing::WithParamInterface<
1024                               ::testing::tuple<R2ReduceWindowTestData, bool>> {
1025 protected:
1026  R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1027
1028  void DoIt() {
1029    ComputationBuilder b(client_, TestName());
1030    const auto& param = ::testing::get<0>(GetParam());
1031    CHECK(param.reducer == kAdd);
1032
1033    const float kInitValue = 0.0f;
1034    Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
1035    std::unique_ptr<Literal> input_literal =
1036        Literal::CreateR2FromArray2DWithLayout(
1037            input, LayoutUtil::MakeLayout(param.layout));
1038
1039    ComputationDataHandle parameter;
1040    auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
1041                                                       &b, &parameter);
1042    auto init_value =
1043        CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
1044    b.ReduceWindow(/*operand=*/parameter,
1045                   /*init_value=*/init_value,
1046                   /*computation=*/CreateScalarAddComputation(FloatType(), &b),
1047                   /*window_dimensions=*/param.window_bounds,
1048                   /*window_strides=*/param.strides, /*padding=*/param.padding);
1049
1050    auto expected = ReferenceUtil::ReduceWindow2DAdd(
1051        /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
1052        /*stride=*/param.strides, /*padding=*/param.padding);
1053
1054    ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
1055                             {input_arg.get()}, DefaultErrorSpec());
1056  }
1057};
1058
// Runs R2ReduceWindowTest::DoIt for each parameterization below.
TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }

// Crosses every kR2TestCases entry with every value of use_bfloat16_params.
// Names come from R2ReduceWindowTestDataToString.
INSTANTIATE_TEST_CASE_P(
    R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
    ::testing::Combine(::testing::ValuesIn(kR2TestCases),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R2ReduceWindowTestDataToString);
1066
// Reuses R2ReduceWindowTest::DoIt for R2 cases that are disabled on the CPU,
// CPU-parallel and GPU backends.
// NOTE(review): the class name suggests the failure is BF16-specific, but the
// instantiation below disables both the F32 and BF16 variants on those
// backends. Confirm whether F32 should run there.
class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {};

// TODO(b/72234705): Fix the test cases that fail on CPU and GPU.
XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test,
           DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
  DoIt();
}

// A window covering the entire 8x128 input, so the output has a single element.
const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = {
    {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128},
     /*strides=*/{1, 1}, /*layout=*/{1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
};

// Crosses the failing cases with every value of use_bfloat16_params.
INSTANTIATE_TEST_CASE_P(
    R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test,
    ::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R2ReduceWindowTestDataToString);
1086
// One rank-1 reduce-window case consumed by R1ReduceWindowTest. Padding is
// given explicitly as (pad_low, pad_high). Most entries derive it from
// xla::MakePadding for the given input/window/stride and a kValid or kSame
// scheme; the last two use hand-written asymmetric padding.
struct R1ReduceWindowTestData {
  // Length of the input vector.
  int64 base_bounds[1];
  // Window size (window_dimensions).
  int64 window_bounds[1];
  // Window stride (window_strides).
  int64 strides[1];
  // Low-side padding; first element of the padding pair.
  int64 pad_low[1];
  // High-side padding; second element of the padding pair.
  int64 pad_high[1];
  // kAdd or kMax; the test body CHECKs for one of these two.
  Reducer reducer;
} kR1TestCases[] = {
    {/*base_bounds=*/{1}, /*window_bounds=*/{1},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{3}, /*window_bounds=*/{3},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{3}, /*window_bounds=*/{2},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{5}, /*window_bounds=*/{1},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kMax},

    {/*base_bounds=*/{16}, /*window_bounds=*/{4},
     /*strides=*/{4},
     /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kMax},

    {/*base_bounds=*/{16}, /*window_bounds=*/{4},
     /*strides=*/{3},
     /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{128 * 2},
     /*window_bounds=*/{30},
     /*strides=*/{27},
     /*pad_low=*/
     {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first},
     /*pad_high=*/
     {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{128 * 17},
     /*window_bounds=*/{7},
     /*strides=*/{64},
     /*pad_low=*/
     {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first},
     /*pad_high=*/
     {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    // Stride larger than the window, so some elements are never reduced.
    {/*base_bounds=*/{128 * 2},
     /*window_bounds=*/{32},
     /*strides=*/{56},
     /*pad_low=*/
     {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first},
     /*pad_high=*/
     {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{3}, /*window_bounds=*/{2},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first},
     /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{5}, /*window_bounds=*/{3},
     /*strides=*/{2},
     /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first},
     /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{16}, /*window_bounds=*/{4},
     /*strides=*/{3},
     /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first},
     /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second},
     /*reducer=*/Reducer::kAdd},

    // High padding equal to the window size: the last output position's
    // window (input indices 5..9) lies entirely in the padding.
    {/*base_bounds=*/{5}, /*window_bounds=*/{5},
     /*strides=*/{1},
     /*pad_low=*/{0},
     /*pad_high=*/{5},
     /*reducer=*/Reducer::kAdd},

    // Low padding equal to the window size: the first output position's
    // window (input indices -5..-1) lies entirely in the padding.
    {/*base_bounds=*/{5}, /*window_bounds=*/{5},
     /*strides=*/{1},
     /*pad_low=*/{5},
     /*pad_high=*/{0},
     /*reducer=*/Reducer::kAdd},
};
1188
1189string R1ReduceWindowTestDataToString(
1190    const ::testing::TestParamInfo<
1191        ::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
1192  const auto& param = ::testing::get<0>(data.param);
1193  string str = tensorflow::strings::StrCat(
1194      "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),
1195      "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"),
1196      "__strides_", tensorflow::str_util::Join(param.strides, "x"),
1197      "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"),
1198      "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"),
1199      "__reducer_", param.reducer == kAdd ? "add" : "max");
1200  if (::testing::get<1>(data.param)) {
1201    str = tensorflow::strings::StrCat(str, "_bfloat16");
1202  }
1203  return str;
1204}
1205
1206class R1ReduceWindowTest : public ReduceWindowTestBase,
1207                           public ::testing::WithParamInterface<
1208                               ::testing::tuple<R1ReduceWindowTestData, bool>> {
1209 protected:
1210  R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1211};
1212
1213TEST_P(R1ReduceWindowTest, DoIt) {
1214  ComputationBuilder b(client_, TestName());
1215  const auto& param = ::testing::get<0>(GetParam());
1216  CHECK(param.reducer == kAdd || param.reducer == kMax);
1217
1218  const float kInitValue = 0.0f;
1219  std::vector<float> input_vector(param.base_bounds[0]);
1220  std::iota(std::begin(input_vector), std::end(input_vector), 0);
1221  std::unique_ptr<Literal> input_literal =
1222      Literal::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector));
1223  ComputationDataHandle parameter;
1224  auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
1225                                                     &b, &parameter);
1226
1227  std::vector<std::pair<int64, int64>> padding(1);
1228  padding[0] = {param.pad_low[0], param.pad_high[0]};
1229
1230  auto computation = param.reducer == kAdd
1231                         ? CreateScalarAddComputation(FloatType(), &b)
1232                         : CreateScalarMaxComputation(FloatType(), &b);
1233  auto init_value =
1234      CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
1235  b.ReduceWindowWithGeneralPadding(
1236      /*operand=*/parameter,
1237      /*init_value=*/init_value,
1238      /*computation=*/computation,
1239      /*window_dimensions=*/param.window_bounds,
1240      /*window_strides=*/param.strides, /*padding=*/padding);
1241
1242  auto reduce_func = param.reducer == kAdd
1243                         ? +[](float a, float b) { return a + b; }
1244                         : +[](float a, float b) { return std::max(a, b); };
1245  auto expected = ReferenceUtil::ReduceWindow1DGeneric(
1246      /*operand=*/tensorflow::gtl::ArraySlice<float>(input_vector),
1247      /*init=*/kInitValue,
1248      /*reduce_func=*/reduce_func,
1249      /*window=*/param.window_bounds,
1250      /*stride=*/param.strides,
1251      /*padding=*/padding);
1252
1253  ComputeAndCompareLiteral(&b, *Literal::CreateR1<float>(*expected),
1254                           {input_arg.get()}, DefaultErrorSpec());
1255}
1256
// Crosses every kR1TestCases entry with every value of use_bfloat16_params.
// Names come from R1ReduceWindowTestDataToString.
INSTANTIATE_TEST_CASE_P(
    R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
    ::testing::Combine(::testing::ValuesIn(kR1TestCases),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R1ReduceWindowTestDataToString);
1262
// Fixture for test cases written as HLO text. Note that these tests compare
// against the results of the interpreter backend.
class ReduceWindowTextTest : public HloTestBase {};
1266
1267TEST_F(ReduceWindowTextTest, R2General256x384) {
1268  const string& hlo_string = R"(
1269HloModule R2Window
1270mul {
1271  lhs = f32[] parameter(0)
1272  rhs = f32[] parameter(1)
1273  ROOT mul = f32[] multiply(lhs, rhs)
1274}
1275ENTRY R2Window {
1276  operand = f32[256,384]{1,0} parameter(0)
1277  constant = f32[] constant(1)
1278  ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
1279}
1280)";
1281  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
1282}
1283
1284TEST_F(ReduceWindowTextTest, R2General256x384Layout01) {
1285  const string& hlo_string = R"(
1286HloModule R2Window
1287mul {
1288lhs = f32[] parameter(0)
1289rhs = f32[] parameter(1)
1290ROOT mul = f32[] multiply(lhs, rhs)
1291}
1292ENTRY R2Window {
1293operand = f32[256,384]{0,1} parameter(0)
1294constant = f32[] constant(1)
1295ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
1296}
1297)";
1298  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
1299}
1300
1301TEST_F(ReduceWindowTextTest, R2General2x5) {
1302  const string& hlo_string = R"(
1303HloModule R2Window
1304mul {
1305  lhs = f32[] parameter(0)
1306  rhs = f32[] parameter(1)
1307  ROOT mul = f32[] multiply(lhs, rhs)
1308}
1309ENTRY R2Window {
1310  operand = f32[2,5]{1,0} parameter(0)
1311  constant = f32[] constant(1)
1312  ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul
1313}
1314)";
1315  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
1316}
1317
1318}  // namespace
1319}  // namespace xla
1320