11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <cmath>
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory>
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h"
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array4d.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/local_client.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h"
271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/reference_util.h"
281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_computation.h"
291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h"
301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_module.h"
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h"
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/statusor.h"
331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test.h"
341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test_helpers.h"
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h"
381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_macros.h"
391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_utils.h"
401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/util.h"
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
423d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower#include "tensorflow/core/lib/math/math_util.h"
43913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar#include "tensorflow/core/lib/strings/str_util.h"
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/test.h"
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h"
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
516a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebarclass BatchNormalizationTest
526a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    : public ClientLibraryTestBase,
536a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      public ::testing::WithParamInterface<bool /*use_cudnn_batchnorm*/> {
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected:
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
566a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(GetParam());
576a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    Array2D<float> pz({
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        // z0 z1
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        {-1.0f, 4.1f},  // p0
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        {2.0f, 4.1f},   // p1
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        {5.0f, 4.4f},   // p2
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    });
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    input_array_.FillWithPZ(pz);
657d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan    input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_));
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kSamples, input_array_.planes());
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kZ, input_array_.depth());
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kY, input_array_.height());
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kY, input_array_.width());
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kSamples = 3;
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kX = 1;
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kY = 1;
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kZ = 2;
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> input_array_;
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Literal input_literal_;
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const ErrorSpec error_spec_{0.001, 0.001};
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
826a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar// If testing the GPU backend, run the tests twice, with and without cudnn
836a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar// batchnorm.  Otherwise, just run the tests once -- the value of this flag
846a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar// doesn't matter.
856a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#ifdef XLA_TEST_BACKEND_GPU
866a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarINSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
876a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                        ::testing::Bool());
886a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#else
896a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarINSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
906a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                        ::testing::Values(false));
916a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#endif
926a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
936a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "subtract_in_z_one_sample");
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto x = builder.ConstantLiteral(input_literal_);
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto y = builder.ConstantR1<float>({3.14, 4.25});
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.Sub(x, y, /*broadcast_dimensions=*/{1});
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> expected(kSamples, kZ, kY, kX);
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> pz({
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  });
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected.FillWithPZ(pz);
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1096a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "square_tesseract_elementwise");
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto x = builder.ConstantLiteral(input_literal_);
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.SquareF32(x);
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1143d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower  using tensorflow::MathUtil;
1153d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> expected(kSamples, kZ, kY, kX);
1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> expected_pz({
1183d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower      {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)},
1193d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower      {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)},
1203d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower      {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)},
1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  });
1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected.FillWithPZ(expected_pz);
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1266a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SumToZ) {
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "sum_to_z");
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_activations = builder.ConstantLiteral(input_literal_);
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Computation add = CreateScalarAddComputation(F32, &builder);
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Reduce all but the Z dimension.
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                 {0, 2, 3});
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> expected = {6, 12.6};
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1386a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "square_and_reduce");
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_activations = builder.ConstantLiteral(input_literal_);
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto set_means = builder.ConstantR1<float>({2.f, 4.2f});
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto activation_deviations = builder.Sub(input_activations, set_means,
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           /*broadcast_dimensions=*/{1});
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Computation add = CreateScalarAddComputation(F32, &builder);
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto dev_squares = builder.SquareF32(activation_deviations);
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum_of_squares = builder.Reduce(
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3});
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> expected = {18, 0.06};
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1536a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "variance_to_stddev");
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto variance = builder.ConstantR1<float>({6.f, .02f});
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sqrt = builder.SqrtF32(variance);
1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> expected = {2.44948974f, 0.14142136f};
1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Compare against a forward batch normalization example in the NN spec
1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// reference.
1646a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "batch_normalize_per_spec");
1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_activations =
1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.CheckShape(builder.ConstantLiteral(input_literal_),
1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                         ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto gamma = builder.ConstantR1<float>({1.0, 1.0});
1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto beta = builder.ConstantR1<float>({0.0, 0.0});
1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Computation add = CreateScalarAddComputation(F32, &builder);
1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Reduce all dimensions except dimension 1.
1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum = builder.CheckShape(
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     /*dimensions_to_reduce=*/{0, 2, 3}),
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      TwoElementVectorF32);
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) /
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                         ShapeUtil::ElementsIn(*sum_shape));
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto set_means = builder.Div(sum, count);
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const float kEpsilon = 1e-9f;
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto epsilon = builder.ConstantR0<float>(kEpsilon);
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon});
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto activation_deviations = builder.Sub(input_activations, set_means,
1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           /*broadcast_dimensions=*/{1});
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto dev_squares = builder.SquareF32(activation_deviations);
1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum_of_squares = builder.CheckShape(
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     /*dimensions_to_reduce=*/{0, 2, 3}),
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      TwoElementVectorF32);
1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto variance = builder.Div(sum_of_squares, count);
1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto standard_deviation = builder.SqrtF32(variance);
1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto standard_deviation_above_epsilon = builder.CheckShape(
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto gt_eps = builder.Select(standard_deviation_above_epsilon,
1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               standard_deviation, epsilon2);
2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto normalization_factors = builder.ReciprocalF32(gt_eps);
2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto normalized_input_activations =
2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Mul(activation_deviations, normalization_factors,
2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  /*broadcast_dimensions=*/{1});
2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  /* auto output_activations = */ builder.Add(
2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Mul(normalized_input_activations, gamma,
2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  /*broadcast_dimensions=*/{1}),
2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      beta, /*broadcast_dimensions=*/{1});
2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> expected(kSamples, kZ, kY, kX);
2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> pz({
2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {0.f, -.1f / std::sqrt(.02f)},
2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  });
2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected.FillWithPZ(pz);
2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2206a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, BasicTraining) {
2216a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  const int kFeatureIndex = 3;
2226a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationBuilder builder(client_, TestName());
2236a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2246a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto operand = builder.ConstantR4FromArray4D<float>(
2256a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});
2266a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2276a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});
2286a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2296a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});
2306a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2316a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto tuple = builder.BatchNormTraining(operand, scale, offset,
2326a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                         /*epsilon=*/0.001, kFeatureIndex);
2336a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2347d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  auto expected = Literal::MakeTuple(
2356a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
2366a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                 {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
2376a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar           .get(),
2386a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>({4, 5}).get(),
2396a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>({5, 5}).get()});
2406a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2417d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
2426a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar}
2436a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2446a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) {
2456a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  const int kFeatureIndex = 2;
2466a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationBuilder builder(client_, TestName());
2476a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2486a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto operand = builder.ConstantR4FromArray4D<float>(
2496a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
2506a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2516a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});
2526a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2536a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});
2546a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2556a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto tuple = builder.BatchNormTraining(operand, scale, offset,
2566a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                         /*epsilon=*/0.001, kFeatureIndex);
2576a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2587d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  auto expected = Literal::MakeTuple(
2596a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
2606a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                 {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
2616a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar           .get(),
2626a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>({4, 5}).get(),
2636a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>({5, 5}).get()});
2646a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2657d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
2666a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar}
2676a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2686a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
2696a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  // Use 0 dimension as feature, tests layout analyzer.
2706a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  const int kFeatureIndex = 0;
2716a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationBuilder builder(client_, TestName());
2726a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2736a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationDataHandle h0;
2746a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
2756a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                          /*parameter_number=*/0, "operand",
2766a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                          &builder, &h0);
2776a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationDataHandle h1;
2786a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto scale =
2796a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
2806a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                               /*parameter_number=*/1, "scale", &builder, &h1);
2816a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationDataHandle h2;
2826a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto offset =
2836a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
2846a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                               /*parameter_number=*/2, "offset", &builder, &h2);
2856a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2866a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto tuple = builder.BatchNormTraining(h0, h1, h2,
2876a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                         /*epsilon=*/1, kFeatureIndex);
2886a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2897d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  auto expected = Literal::MakeTuple(
2906a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
2916a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar           .get(),
2926a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
2936a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
2946a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
2957d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  ComputeAndCompareTuple(&builder, *expected,
2966a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                         {operand.get(), scale.get(), offset.get()},
2976a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                         ErrorSpec(0.1));
2986a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar}
2996a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3006a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
3016a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  // Test the correctness of choosing a large epsilon value.
3026a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  const int kFeatureIndex = 2;
3036a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationBuilder builder(client_, TestName());
3046a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3056a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationDataHandle h0;
3066a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}},
3076a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                          /*parameter_number=*/0, "operand",
3086a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                          &builder, &h0);
3096a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationDataHandle h1;
3106a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto scale =
3116a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      CreateR1Parameter<float>(std::vector<float>(1, 1.0f),
3126a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                               /*parameter_number=*/1, "scale", &builder, &h1);
3136a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationDataHandle h2;
3146a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto offset =
3156a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      CreateR1Parameter<float>(std::vector<float>(1, 0.0f),
3166a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                               /*parameter_number=*/2, "offset", &builder, &h2);
3176a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3186a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  // var = 125, mean = 15, epsilon = -100
3196a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto tuple = builder.BatchNormTraining(h0, h1, h2,
3206a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                         /*epsilon=*/-100, kFeatureIndex);
3216a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3227d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  auto expected = Literal::MakeTuple(
3236a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
3246a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar           .get(),
3256a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
3266a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
3276a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3287d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  ComputeAndCompareTuple(&builder, *expected,
3296a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                         {operand.get(), scale.get(), offset.get()},
3306a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                         ErrorSpec(0.1));
3316a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar}
3326a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3336a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
3346a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  const int kFeatureIndex = 2;
3356a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  ComputationBuilder builder(client_, TestName());
3366a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3376a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto operand =
3386a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f));
3396a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3406a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto scale = builder.ConstantR1<float>({1.0f, 1.0f});
3416a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3426a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto mean = builder.ConstantR1<float>({0.0f, 0.0f});
3436a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3446a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto var = builder.ConstantR1<float>({1.0f, 1.0f});
3456a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3466a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto grad_output = builder.ConstantR4FromArray4D<float>(
3476a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
3486a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3496a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  builder.BatchNormGrad(operand, scale, mean, var, grad_output,
3506a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                        /*epsilon=*/0.0, kFeatureIndex);
3516a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3527d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  auto expected = Literal::MakeTuple(
3536a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
3546a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                                 {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
3556a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar           .get(),
3566a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>({0, 0}).get(),
3576a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar       Literal::CreateR1<float>({16, 20}).get()});
3586a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3597d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
3606a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar}
3616a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3621464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerstruct BatchNormTestParam {
3631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<int64> bounds;
3641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  int64 feature_index;
3651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  float random_value_mean;
3661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  float random_value_var;
3676a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  bool use_cudnn_batchnorm;
368913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar
369913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar  friend ::std::ostream& operator<<(::std::ostream& os,
370913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar                                    const BatchNormTestParam& p) {
371913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, ";
372913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "feature_index=" << p.feature_index << ", ";
373913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "random_value_mean=" << p.random_value_mean << ", ";
374913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "random_value_var=" << p.random_value_var;
3756a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3766a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    // Don't print use_cudnn_batchnorm when it's false, because most backends
3776a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    // never set it to true.
3786a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    if (p.use_cudnn_batchnorm) {
3796a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      os << ", use_cudnn_batchnorm=true";
3806a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    }
381913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    return os;
382913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar  }
3831464b9930de871fd11870941963253670f737c23A. Unique TensorFlower};
3841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower// Tests to test the fused operation of BatchNorm.
3866a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebarclass BatchNormTestManySizes
3876a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    : public ClientLibraryTestBase,
3886a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar      public ::testing::WithParamInterface<BatchNormTestParam> {
3896a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar public:
3906a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  BatchNormTestManySizes() {
3916a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(
3926a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar        GetParam().use_cudnn_batchnorm);
3936a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  }
3941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower};
3951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3966a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebarstd::vector<BatchNormTestParam> BuildBatchNormTestParams() {
3976a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  std::vector<BatchNormTestParam> params;
3986a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
3996a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  auto add_testcase = [&](std::vector<int64> bounds, int64 feature_index,
4006a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                          float random_value_mean, float random_value_var) {
4016a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    BatchNormTestParam p{bounds, feature_index, random_value_mean,
4026a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                         random_value_var, /*use_cudnn_batchnorm=*/false};
4036a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    params.push_back(p);
4046a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
4056a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    // If testing the GPU backend, also run with cudnn batchnorm enabled.
4066a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#ifdef XLA_TEST_BACKEND_GPU
4076a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    p.use_cudnn_batchnorm = true;
4086a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar    params.push_back(p);
4096a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#endif
4106a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  };
4116a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
4126a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f);
4136a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f);
4146a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
4156a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f);
4166a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f);
4176a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f);
4186a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f);
4196a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f);
4206a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f);
4216a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f);
4226a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f);
4236a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
4246a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({24, 129, 1, 2}, 2, 10000, 10000);
4256a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({24, 129, 1, 2}, 3, 10000, 10000);
4266a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
4276a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  // Feature on low dimension to trigger relayout, check that internal logical
4286a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  // to physical dimension calculation is correct after relayout.
4296a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  add_testcase({1, 2, 3, 4}, 0, 100, 100);
4306a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
4311654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // Zero-sized tensor.
4321654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  add_testcase({1, 0, 100, 42}, 0, 100, 100);
4331654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar
4346a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar  return params;
4356a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar}
4366a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
4376a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarINSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes,
4386a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar                        ::testing::ValuesIn(BuildBatchNormTestParams()));
4396a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar
4406a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
4411464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  float epsilon = 0.001;
4421464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
4431464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const std::vector<int64>& bounds = GetParam().bounds;
4441464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
4451464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  input_array.FillRandom(GetParam().random_value_var,
4461464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         GetParam().random_value_mean);
4471464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int64 feature_index = GetParam().feature_index;
4491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int64 num_elements_per_feature =
4501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      Product(bounds) / bounds[feature_index];
4511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int64 feature_bound = bounds[feature_index];
4521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> offset(feature_bound, 1);
4531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> scale(feature_bound, 2);
4541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto input_squared =
4561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
4571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<int64> reduce_dims;
458030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
4591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    if (i != feature_index) {
4601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      reduce_dims.push_back(i);
4611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    }
4621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
4631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto sum =
4651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
4661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
4671464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto sum_squared =
4691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
4701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
4711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> mean(feature_bound);
4731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
4751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    mean[i] = sum[i] / num_elements_per_feature;
4761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
4771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> mean_square(feature_bound);
4791464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
4801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    mean_square[i] = mean[i] * mean[i];
4811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
4821464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4831464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> square_mean(feature_bound);
4841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
4851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    square_mean[i] = sum_squared[i] / num_elements_per_feature;
4861464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
4871464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
4881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> var(feature_bound);
4891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
4901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    var[i] = square_mean[i] - mean_square[i];
4911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
4921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
493f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  Array4D<float> mean4D =
4941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
495f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
496f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
497f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto offset4D =
4981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
4991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
500f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
501f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower                                                scale4D, offset4D, epsilon);
5021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
5031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);
5041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
5051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset_literal = Literal::CreateR1<float>(offset);
5061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale_literal = Literal::CreateR1<float>(scale);
5071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
5081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
5091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto input_activations =
5101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      builder.Parameter(0, input_literal->shape(), "input");
5111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale_activations =
5121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      builder.Parameter(1, scale_literal->shape(), "offset");
5131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset_activations =
5141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      builder.Parameter(2, offset_literal->shape(), "scale");
5151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
5167d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  auto expected = Literal::MakeTuple({expected_normalized.get(),
5177d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan                                      Literal::CreateR1<float>(mean).get(),
5187d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan                                      Literal::CreateR1<float>(var).get()});
5191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
5201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::unique_ptr<GlobalData> input_data =
5211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
5221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::unique_ptr<GlobalData> scale_data =
5231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
5241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::unique_ptr<GlobalData> offset_data =
5251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
5261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
5271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  builder.BatchNormTraining(input_activations, scale_activations,
5281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                            offset_activations, epsilon, feature_index);
5291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
5301654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
5311654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // disables constant folding, but we want it enabled for our zero-sized tensor
5321654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // testcase.
5331654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
5341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(
5357d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan      &builder, *expected,
5361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {input_data.get(), scale_data.get(), offset_data.get()},
5371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ErrorSpec(0.01, 1));
5381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
5391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
5406a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
5417359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  float epsilon = 0.001;
5427359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
5437359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const std::vector<int64>& bounds = GetParam().bounds;
5447359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
5457359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  input_array.FillRandom(GetParam().random_value_var,
5467359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                         GetParam().random_value_mean);
5477359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5487359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const int64 feature_index = GetParam().feature_index;
5497359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const int64 num_elements_per_feature =
5507359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      Product(bounds) / bounds[feature_index];
5517359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const int64 feature_bound = bounds[feature_index];
5527359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> offset(feature_bound, 1);
5537359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> scale(feature_bound, 2);
5547359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5557359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto input_squared =
5567359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
5577359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<int64> reduce_dims;
5587359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
5597359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    if (i != feature_index) {
5607359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      reduce_dims.push_back(i);
5617359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    }
5627359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
5637359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5647359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto sum =
5657359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
5667359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
5677359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5687359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto sum_squared =
5697359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
5707359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
5717359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5727359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> mean(feature_bound);
5737359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5747359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
5757359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    mean[i] = sum[i] / num_elements_per_feature;
5767359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
5777359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5787359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> mean_square(feature_bound);
5797359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
5807359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    mean_square[i] = mean[i] * mean[i];
5817359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
5827359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5837359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> square_mean(feature_bound);
5847359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
5857359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    square_mean[i] = sum_squared[i] / num_elements_per_feature;
5867359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
5877359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5887359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> var(feature_bound);
5897359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
5907359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    var[i] = square_mean[i] - mean_square[i];
5917359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
5927359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
5937359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  Array4D<float> mean4D =
5947359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
5957359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
5967359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
5977359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto offset4D =
5987359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
5997359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
6007359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
6017359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                                                scale4D, offset4D, epsilon);
6027359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
6037359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto offset_literal = Literal::CreateR1<float>(offset);
6047359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto scale_literal = Literal::CreateR1<float>(scale);
6057359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto mean_literal = Literal::CreateR1<float>(mean);
6067359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto var_literal = Literal::CreateR1<float>(var);
6077359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
6087359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
6097359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto input_activations =
6107359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(0, input_literal->shape(), "input");
6117359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto scale_activations =
6127359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(1, scale_literal->shape(), "offset");
6137359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto offset_activations =
6147359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(2, offset_literal->shape(), "scale");
6157359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
6167359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto variance_activations =
6177359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(4, var_literal->shape(), "variance");
6187359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
6197359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  Array4D<float> expected = normalized;
6207359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
6217359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> input_data =
6227359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
6237359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> scale_data =
6247359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
6257359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> offset_data =
6267359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
6277359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> mean_data =
6287359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
6297359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> variance_data =
6307359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
6317359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
6327359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  builder.BatchNormInference(input_activations, scale_activations,
6337359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                             offset_activations, mean_activations,
6347359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                             variance_activations, epsilon, feature_index);
6357359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
6361654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
6371654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // disables constant folding, but we want it enabled for our zero-sized tensor
6381654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // testcase.
6391654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
6401654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar
6417359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  ComputeAndCompareR4<float>(
6427359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      &builder, expected,
6437359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
6447359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower       variance_data.get()},
6457359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ErrorSpec(0.01, 1));
6467359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower}
6477359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
6486a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
649ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  float epsilon = 0.001;
650ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
651ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const std::vector<int64>& bounds = GetParam().bounds;
652ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
653ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  input_array.FillRandom(GetParam().random_value_var,
654ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                         GetParam().random_value_mean);
655ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
656ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
657ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  grad_output_array.FillRandom(GetParam().random_value_var,
658ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                               GetParam().random_value_mean);
659ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
660ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int64 feature_index = GetParam().feature_index;
661ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int64 num_elements_per_feature =
662ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      Product(bounds) / bounds[feature_index];
663ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int64 feature_bound = bounds[feature_index];
664ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> scale(feature_bound, 2);
665ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
666ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto input_squared =
667ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
668ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<int64> reduce_dims;
669030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
670ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    if (i != feature_index) {
671ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      reduce_dims.push_back(i);
672ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    }
673ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
674ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
675ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto sum =
676ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
677ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                  [](float a, float b) { return a + b; });
678ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
679ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto sum_squared =
680ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
681ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                  [](float a, float b) { return a + b; });
682ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
683ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> mean(feature_bound);
684ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
685ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
6861654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar    if (num_elements_per_feature > 0) {
6871654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar      mean[i] = sum[i] / num_elements_per_feature;
6881654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar    } else {
6891654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar      mean[i] = 0;
6901654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar    }
691ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
692ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
693ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> mean_square(feature_bound);
694ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
695ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    mean_square[i] = mean[i] * mean[i];
696ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
697ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
698ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> square_mean(feature_bound);
699ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
7001654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar    if (num_elements_per_feature > 0) {
7011654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar      square_mean[i] = sum_squared[i] / num_elements_per_feature;
7021654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar    } else {
7031654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar      square_mean[i] = 0;
7041654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar    }
705ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
706ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
707ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> var(feature_bound);
708ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
709ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    var[i] = square_mean[i] - mean_square[i];
710ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
711ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
712f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  Array4D<float> mean4D =
713ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
714f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
715f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
716ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
717ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
718030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      var4D, [epsilon](float a) { return a + epsilon; });
719030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
720030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
721030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); });
722ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
723ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output_times_var =
724ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
725ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 [](float a, float b) { return a * b; });
726ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
727ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto activation_shifted = *ReferenceUtil::MapArray4D(
728f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower      input_array, mean4D, [](float a, float b) { return a - b; });
729ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
730030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto activation_shifted_times_grad_output =
731030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
732ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 [](float a, float b) { return a * b; });
733ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
734030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
735030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
736030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      [](float a, float b) { return a * b; });
737030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
738ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
739ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
740ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      [](float a, float b) { return a + b; });
741ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
742ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_offset =
743ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
744ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                  [](float a, float b) { return a + b; });
745ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
746030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
747030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });
748030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
749030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I1 = *ReferenceUtil::MapArray4D(
750030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_output_array, [&](float a) { return num_elements_per_feature * a; });
751030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
752030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);
753030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
754030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  // I3 = sum(output_grad * (activation - mean(activation)))
755030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
756030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
757030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                  /*init=*/0.0f, reduce_dims,
758030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                  [](float a, float b) { return a + b; }),
759030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      bounds, feature_index);
760030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
761030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  // I4 = (activation - mean(activation)) *
762030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  //   sum(output_grad * (activation - mean(activation)))
763030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
764030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                       [](float a, float b) { return a * b; });
765030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
766030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  // I5 = (activation - mean(activation)) *
767030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  //   sum(output_grad * (activation - mean(activation))) / (variance +
768030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  //   epsilon))
769030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
770030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                       [](float a, float b) { return a / b; });
771030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
772030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto grad_activation = *ReferenceUtil::MapArray4D(
773030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      I1, I2, [](float a, float b) { return a - b; });
774030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
775030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  grad_activation = *ReferenceUtil::MapArray4D(
776030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_activation, I5, [](float a, float b) { return a - b; });
777030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
778030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  grad_activation = *ReferenceUtil::MapArray4D(
779030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_activation, scale4D, [](float a, float b) { return a * b; });
780030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
781030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  grad_activation = *ReferenceUtil::MapArray4D(
7821654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar      grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) {
7831654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar        if (num_elements_per_feature > 0) {
7841654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar          return a * b / num_elements_per_feature;
7851654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar        }
7861654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar        return 0.f;
7871654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar      });
788030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
789ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto expected_grad_activation =
790ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      Literal::CreateR4FromArray4D<float>(grad_activation);
791ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
792ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
793ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto scale_literal = Literal::CreateR1<float>(scale);
794ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto mean_literal = Literal::CreateR1<float>(mean);
795ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var_literal = Literal::CreateR1<float>(var);
796ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output_literal =
797ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      Literal::CreateR4FromArray4D<float>(grad_output_array);
798ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
799ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto input_parameter = builder.Parameter(0, input_literal->shape(), "input");
800ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale");
801ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean");
802ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance");
803ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output_parameter =
804ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      builder.Parameter(4, grad_output_literal->shape(), "grad_output");
805ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
806ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> input_data =
807ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
808ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> scale_data =
809ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
810ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> mean_data =
811ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
812ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> var_data =
813ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
814ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> grad_output_data =
815ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
816ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
817ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto t = builder.BatchNormGrad(input_parameter, scale_parameter,
818ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 mean_parameter, var_parameter,
819ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 grad_output_parameter, epsilon, feature_index);
820ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
821ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto expected =
8227d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan      Literal::MakeTuple({expected_grad_activation.get(),
8237d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan                          Literal::CreateR1<float>(grad_scale).get(),
8247d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan                          Literal::CreateR1<float>(grad_offset).get()});
825ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
8261654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
8271654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // disables constant folding, but we want it enabled for our zero-sized tensor
8281654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  // testcase.
8291654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
8301654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar
8317d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan  ComputeAndCompareTuple(&builder, *expected,
832ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                         {input_data.get(), scale_data.get(), mean_data.get(),
833ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                          var_data.get(), grad_output_data.get()},
834ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                         ErrorSpec(0.01, 1));
835ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower}
836ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
8371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
8381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
839