/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

class BatchNormalizationTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<bool /*use_cudnn_batchnorm*/> {
 protected:
  BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(GetParam());

    Array2D<float> pz({
        // z0 z1
        {-1.0f, 4.1f},  // p0
        {2.0f, 4.1f},   // p1
        {5.0f, 4.4f},   // p2
    });
    input_array_.FillWithPZ(pz);
    input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_));
    CHECK_EQ(kSamples, input_array_.planes());
    CHECK_EQ(kZ, input_array_.depth());
    CHECK_EQ(kY, input_array_.height());
    CHECK_EQ(kX, input_array_.width());
  }

  static constexpr int64 kSamples = 3;
  static constexpr int64 kX = 1;
  static constexpr int64 kY = 1;
  static constexpr int64 kZ = 2;

  Array4D<float> input_array_;
  Literal input_literal_;
  const ErrorSpec error_spec_{0.001, 0.001};
};

// If testing the GPU backend, run the tests twice, with and without cudnn
// batchnorm.  Otherwise, just run the tests once -- the value of this flag
// doesn't matter.
#ifdef XLA_TEST_BACKEND_GPU
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Bool());
#else
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Values(false));
#endif

XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
  ComputationBuilder builder(client_, "subtract_in_z_one_sample");
  auto x = builder.ConstantLiteral(input_literal_);
  auto y = builder.ConstantR1<float>({3.14f, 4.25f});
  builder.Sub(x, y, /*broadcast_dimensions=*/{1});

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
  });
  expected.FillWithPZ(pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
  ComputationBuilder builder(client_, "square_tesseract_elementwise");
  auto x = builder.ConstantLiteral(input_literal_);
  builder.SquareF32(x);

  using tensorflow::MathUtil;

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> expected_pz({
      {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)},
  });
  expected.FillWithPZ(expected_pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SumToZ) {
  ComputationBuilder builder(client_, "sum_to_z");
  auto input_activations = builder.ConstantLiteral(input_literal_);
  Computation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all but the Z dimension.
  builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
                 {0, 2, 3});

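  // Per-feature sums over the reduced dimensions:
  //   z0: -1 + 2 + 5 = 6,  z1: 4.1 + 4.1 + 4.4 = 12.6.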
  std::vector<float> expected = {6, 12.6};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
  ComputationBuilder builder(client_, "square_and_reduce");
  auto input_activations = builder.ConstantLiteral(input_literal_);
  auto set_means = builder.ConstantR1<float>({2.f, 4.2f});
  auto activation_deviations = builder.Sub(input_activations, set_means,
                                           /*broadcast_dimensions=*/{1});
  Computation add = CreateScalarAddComputation(F32, &builder);
  auto dev_squares = builder.SquareF32(activation_deviations);
  auto sum_of_squares = builder.Reduce(
      dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3});

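  // Deviations from the set means {2, 4.2} are {-3, 0, 3} for z0 and
  // {-0.1, -0.1, 0.2} for z1, so the sums of squares are
  //   z0: 9 + 0 + 9 = 18,  z1: 0.01 + 0.01 + 0.04 = 0.06.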
  std::vector<float> expected = {18, 0.06};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
  ComputationBuilder builder(client_, "variance_to_stddev");
  auto variance = builder.ConstantR1<float>({6.f, .02f});
  auto sqrt = builder.SqrtF32(variance);

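  // sqrt(6) ~= 2.44948974 and sqrt(0.02) ~= 0.14142136.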
  std::vector<float> expected = {2.44948974f, 0.14142136f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Compare against a forward batch normalization example in the NN spec
// reference.
XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
  ComputationBuilder builder(client_, "batch_normalize_per_spec");
  auto input_activations =
      builder.CheckShape(builder.ConstantLiteral(input_literal_),
                         ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
  auto gamma = builder.ConstantR1<float>({1.0, 1.0});
  auto beta = builder.ConstantR1<float>({0.0, 0.0});
  Computation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all dimensions except dimension 1.
  Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
  auto sum = builder.CheckShape(
      builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
                     /*dimensions_to_reduce=*/{0, 2, 3}),
      TwoElementVectorF32);
  auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
  auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
  auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) /
                                         ShapeUtil::ElementsIn(*sum_shape));
  auto set_means = builder.Div(sum, count);

  const float kEpsilon = 1e-9f;
  auto epsilon = builder.ConstantR0<float>(kEpsilon);
  auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon});
  auto activation_deviations = builder.Sub(input_activations, set_means,
                                           /*broadcast_dimensions=*/{1});
  auto dev_squares = builder.SquareF32(activation_deviations);
  auto sum_of_squares = builder.CheckShape(
      builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
                     /*dimensions_to_reduce=*/{0, 2, 3}),
      TwoElementVectorF32);
  auto variance = builder.Div(sum_of_squares, count);
  auto standard_deviation = builder.SqrtF32(variance);
  auto standard_deviation_above_epsilon = builder.CheckShape(
      builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
  auto gt_eps = builder.Select(standard_deviation_above_epsilon,
                               standard_deviation, epsilon2);
  auto normalization_factors = builder.ReciprocalF32(gt_eps);
  auto normalized_input_activations =
      builder.Mul(activation_deviations, normalization_factors,
                  /*broadcast_dimensions=*/{1});
  /* auto output_activations = */ builder.Add(
      builder.Mul(normalized_input_activations, gamma,
                  /*broadcast_dimensions=*/{1}),
      beta, /*broadcast_dimensions=*/{1});

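  // Reference check: the per-feature mean is {2, 4.2} and the variance is
  // {6, 0.02}.  With gamma = 1 and beta = 0 the output is simply
  // (activation - mean) / stddev, which is what pz spells out below.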
  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
      {0.f, -.1f / std::sqrt(.02f)},
      {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
  });
  expected.FillWithPZ(pz);

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
  const int kFeatureIndex = 3;
  ComputationBuilder builder(client_, TestName());

  auto operand = builder.ConstantR4FromArray4D<float>(
      {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});

  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});

  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});

  auto tuple = builder.BatchNormTraining(operand, scale, offset,
                                         /*epsilon=*/0.001, kFeatureIndex);

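  // Along feature dimension 3, feature 0 sees {1, 3, 5, 7} and feature 1 sees
  // {2, 4, 6, 8}, so mean = {4, 5} and variance = {5, 5}.  The expected output
  // is scale * (x - mean) / sqrt(var + epsilon) + offset.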
  auto expected = Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                 {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
           .get(),
       Literal::CreateR1<float>({4, 5}).get(),
       Literal::CreateR1<float>({5, 5}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) {
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  auto operand = builder.ConstantR4FromArray4D<float>(
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});

  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});

  auto tuple = builder.BatchNormTraining(operand, scale, offset,
                                         /*epsilon=*/0.001, kFeatureIndex);

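  // Same per-feature values as BasicTraining, but along dimension 2:
  // mean = {4, 5} and variance = {5, 5}.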
  auto expected = Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                                 {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
           .get(),
       Literal::CreateR1<float>({4, 5}).get(),
       Literal::CreateR1<float>({5, 5}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
  // Use dimension 0 as the feature dimension; this exercises the layout
  // analyzer.
  const int kFeatureIndex = 0;
  ComputationBuilder builder(client_, TestName());

  ComputationDataHandle h0;
  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  ComputationDataHandle h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  ComputationDataHandle h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  auto tuple = builder.BatchNormTraining(h0, h1, h2,
                                         /*epsilon=*/1, kFeatureIndex);

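  // Every input element is 1, so per-feature mean = 1 and variance = 0.  With
  // epsilon = 1 the normalized value is (1 - 1) / sqrt(0 + 1) = 0, and
  // scale * 0 + offset = 1 everywhere.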
  auto expected = Literal::MakeTuple(
      {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
           .get(),
       Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
       Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});

  ComputeAndCompareTuple(&builder, *expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
  // Test correctness when a large-magnitude (here negative) epsilon value is
  // chosen.
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  ComputationDataHandle h0;
  auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}},
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  ComputationDataHandle h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(1, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  ComputationDataHandle h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(1, 0.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  // var = 125, mean = 15, epsilon = -100
  auto tuple = builder.BatchNormTraining(h0, h1, h2,
                                         /*epsilon=*/-100, kFeatureIndex);

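  // var + epsilon = 125 - 100 = 25, so the normalization factor is
  // 1 / sqrt(25) = 1 / 5, giving outputs (x - 15) / 5 = {-3, -1, 1, 3}.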
  auto expected = Literal::MakeTuple(
      {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
           .get(),
       Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
       Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});

  ComputeAndCompareTuple(&builder, *expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  auto operand =
      builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f));

  auto scale = builder.ConstantR1<float>({1.0f, 1.0f});

  auto mean = builder.ConstantR1<float>({0.0f, 0.0f});

  auto var = builder.ConstantR1<float>({1.0f, 1.0f});

  auto grad_output = builder.ConstantR4FromArray4D<float>(
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  builder.BatchNormGrad(operand, scale, mean, var, grad_output,
                        /*epsilon=*/0.0, kFeatureIndex);

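  // Since the operand equals the mean everywhere and scale = 1, var = 1,
  // epsilon = 0, the expected grad_scale is 0 and grad_activation reduces to
  // grad_output minus its per-feature mean ({4, 5} for features 0 and 1).
  // grad_offset is the per-feature sum of grad_output: {16, 20}.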
  auto expected = Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
                                 {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
           .get(),
       Literal::CreateR1<float>({0, 0}).get(),
       Literal::CreateR1<float>({16, 20}).get()});

  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}

struct BatchNormTestParam {
  std::vector<int64> bounds;
  int64 feature_index;
  float random_value_mean;
  float random_value_var;
  bool use_cudnn_batchnorm;

  friend ::std::ostream& operator<<(::std::ostream& os,
                                    const BatchNormTestParam& p) {
    os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, ";
    os << "feature_index=" << p.feature_index << ", ";
    os << "random_value_mean=" << p.random_value_mean << ", ";
    os << "random_value_var=" << p.random_value_var;

    // Don't print use_cudnn_batchnorm when it's false, because most backends
    // never set it to true.
    if (p.use_cudnn_batchnorm) {
      os << ", use_cudnn_batchnorm=true";
    }
    return os;
  }
};

// Tests for the fused BatchNorm operation across many shapes and feature
// indices.
class BatchNormTestManySizes
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<BatchNormTestParam> {
 public:
  BatchNormTestManySizes() {
    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(
        GetParam().use_cudnn_batchnorm);
  }
};

std::vector<BatchNormTestParam> BuildBatchNormTestParams() {
  std::vector<BatchNormTestParam> params;

  auto add_testcase = [&](std::vector<int64> bounds, int64 feature_index,
                          float random_value_mean, float random_value_var) {
    BatchNormTestParam p{bounds, feature_index, random_value_mean,
                         random_value_var, /*use_cudnn_batchnorm=*/false};
    params.push_back(p);

    // If testing the GPU backend, also run with cudnn batchnorm enabled.
#ifdef XLA_TEST_BACKEND_GPU
    p.use_cudnn_batchnorm = true;
    params.push_back(p);
#endif
  };

  add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f);
  add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f);

  add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f);
  add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f);
  add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f);

  add_testcase({24, 129, 1, 2}, 2, 10000, 10000);
  add_testcase({24, 129, 1, 2}, 3, 10000, 10000);

  // Feature on a low dimension to trigger relayout; checks that the internal
  // logical-to-physical dimension calculation is correct after relayout.
  add_testcase({1, 2, 3, 4}, 0, 100, 100);

  // Zero-sized tensor.
  add_testcase({1, 0, 100, 42}, 0, 100, 100);

  return params;
}

INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes,
                        ::testing::ValuesIn(BuildBatchNormTestParams()));

XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
  float epsilon = 0.001;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

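  // Reference per-feature moments; the variance below uses the identity
  // Var[X] = E[X^2] - (E[X])^2.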
  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);

  auto offset_literal = Literal::CreateR1<float>(offset);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      builder.Parameter(0, input_literal->shape(), "input");
  auto scale_activations =
      builder.Parameter(1, scale_literal->shape(), "scale");
  auto offset_activations =
      builder.Parameter(2, offset_literal->shape(), "offset");

  auto expected = Literal::MakeTuple({expected_normalized.get(),
                                      Literal::CreateR1<float>(mean).get(),
                                      Literal::CreateR1<float>(var).get()});

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();

  builder.BatchNormTraining(input_activations, scale_activations,
                            offset_activations, epsilon, feature_index);

  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized
  // tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompareTuple(
      &builder, *expected,
      {input_data.get(), scale_data.get(), offset_data.get()},
      ErrorSpec(0.01, 1));
}

XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
  float epsilon = 0.001;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

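  // Reference per-feature moments, computed exactly as in the training test
  // via Var[X] = E[X^2] - (E[X])^2.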
  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto offset_literal = Literal::CreateR1<float>(offset);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto mean_literal = Literal::CreateR1<float>(mean);
  auto var_literal = Literal::CreateR1<float>(var);
  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      builder.Parameter(0, input_literal->shape(), "input");
  auto scale_activations =
      builder.Parameter(1, scale_literal->shape(), "scale");
  auto offset_activations =
      builder.Parameter(2, offset_literal->shape(), "offset");
  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
  auto variance_activations =
      builder.Parameter(4, var_literal->shape(), "variance");

  Array4D<float> expected = normalized;

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> variance_data =
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();

  builder.BatchNormInference(input_activations, scale_activations,
                             offset_activations, mean_activations,
                             variance_activations, epsilon, feature_index);

  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized
  // tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareR4<float>(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
       variance_data.get()},
      ErrorSpec(0.01, 1));
}

XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
  float epsilon = 0.001;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  grad_output_array.FillRandom(GetParam().random_value_var,
                               GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

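  // Reference per-feature moments via Var[X] = E[X^2] - (E[X])^2; the
  // divisions below are guarded so the zero-sized tensor case
  // (num_elements_per_feature == 0) stays well-defined.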
  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      mean[i] = sum[i] / num_elements_per_feature;
    } else {
      mean[i] = 0;
    }
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      square_mean[i] = sum_squared[i] / num_elements_per_feature;
    } else {
      square_mean[i] = 0;
    }
  }

  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);

  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
      var4D, [epsilon](float a) { return a + epsilon; });

  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });

  auto grad_output_times_var =
      *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                 [](float a, float b) { return a * b; });

  auto activation_shifted = *ReferenceUtil::MapArray4D(
      input_array, mean4D, [](float a, float b) { return a - b; });

  auto activation_shifted_times_grad_output =
      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
                                 [](float a, float b) { return a * b; });

  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
      [](float a, float b) { return a * b; });

  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
      [](float a, float b) { return a + b; });

  auto grad_offset =
      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });

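  // I1 = num_elements_per_feature * output_grad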
  auto I1 = *ReferenceUtil::MapArray4D(
      grad_output_array, [&](float a) { return num_elements_per_feature * a; });

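  // I2 = sum(output_grad), broadcast back to the input shape.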
  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);

  // I3 = sum(output_grad * (activation - mean(activation)))
  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
                                  /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; }),
      bounds, feature_index);

  // I4 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation)))
  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
                                       [](float a, float b) { return a * b; });

  // I5 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon)
  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
                                       [](float a, float b) { return a / b; });

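  // The pieces are combined below into
  //   grad_activation = scale * rsqrt(var + epsilon) *
  //       (I1 - I2 - I5) / num_elements_per_feature,
  // the decomposition this reference implementation uses for BatchNormGrad.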
  auto grad_activation = *ReferenceUtil::MapArray4D(
      I1, I2, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, I5, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, scale4D, [](float a, float b) { return a * b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) {
        if (num_elements_per_feature > 0) {
          return a * b / num_elements_per_feature;
        }
        return 0.f;
      });

  auto expected_grad_activation =
      Literal::CreateR4FromArray4D<float>(grad_activation);

  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto mean_literal = Literal::CreateR1<float>(mean);
  auto var_literal = Literal::CreateR1<float>(var);
  auto grad_output_literal =
      Literal::CreateR4FromArray4D<float>(grad_output_array);

  auto input_parameter = builder.Parameter(0, input_literal->shape(), "input");
  auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale");
  auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean");
  auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance");
  auto grad_output_parameter =
      builder.Parameter(4, grad_output_literal->shape(), "grad_output");

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> var_data =
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> grad_output_data =
      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();

  auto t = builder.BatchNormGrad(input_parameter, scale_parameter,
                                 mean_parameter, var_parameter,
                                 grad_output_parameter, epsilon, feature_index);

  auto expected =
      Literal::MakeTuple({expected_grad_activation.get(),
                          Literal::CreateR1<float>(grad_scale).get(),
                          Literal::CreateR1<float>(grad_offset).get()});

  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized
  // tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareTuple(&builder, *expected,
                         {input_data.get(), scale_data.get(), mean_data.get(),
                          var_data.get(), grad_output_data.get()},
                         ErrorSpec(0.01, 1));
}

}  // namespace
}  // namespace xla