batch_normalization_test.cc revision 913175c2bd38f6e97de399b29cfe1195bffbaa25
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <cmath>
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory>
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h"
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array4d.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/local_client.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h"
271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/reference_util.h"
281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_computation.h"
291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h"
301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_module.h"
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h"
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/statusor.h"
331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test.h"
341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test_helpers.h"
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h"
381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_macros.h"
391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_utils.h"
401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/util.h"
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
42913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar#include "tensorflow/core/lib/strings/str_util.h"
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/test.h"
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h"
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass BatchNormalizationTest : public ClientLibraryTestBase {
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected:
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    Array2D<float> pz({
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        // z0 z1
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        {-1.0f, 4.1f},  // p0
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        {2.0f, 4.1f},   // p1
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        {5.0f, 4.4f},   // p2
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    });
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    input_array_.FillWithPZ(pz);
6046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower    input_literal_ = *Literal::CreateR4FromArray4D(input_array_);
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kSamples, input_array_.planes());
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kZ, input_array_.depth());
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kY, input_array_.height());
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kY, input_array_.width());
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kSamples = 3;
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kX = 1;
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kY = 1;
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kZ = 2;
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> input_array_;
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Literal input_literal_;
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const ErrorSpec error_spec_{0.001, 0.001};
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SubtractInZ) {
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "subtract_in_z_one_sample");
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto x = builder.ConstantLiteral(input_literal_);
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto y = builder.ConstantR1<float>({3.14, 4.25});
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.Sub(x, y, /*broadcast_dimensions=*/{1});
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> expected(kSamples, kZ, kY, kX);
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> pz({
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  });
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected.FillWithPZ(pz);
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SquareTesseractElementwise) {
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "square_tesseract_elementwise");
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto x = builder.ConstantLiteral(input_literal_);
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.SquareF32(x);
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> expected(kSamples, kZ, kY, kX);
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> expected_pz({
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {std::pow(-1.0f, 2.0f), std::pow(4.1f, 2.0f)},
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {std::pow(2.0f, 2.0f), std::pow(4.1f, 2.0f)},
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {std::pow(5.0f, 2.0f), std::pow(4.4f, 2.0f)},
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  });
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected.FillWithPZ(expected_pz);
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SumToZ) {
1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "sum_to_z");
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_activations = builder.ConstantLiteral(input_literal_);
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Computation add = CreateScalarAddComputation(F32, &builder);
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Reduce all but the Z dimension.
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                 {0, 2, 3});
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> expected = {6, 12.6};
1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SquareAndReduce) {
1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "square_and_reduce");
1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_activations = builder.ConstantLiteral(input_literal_);
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto set_means = builder.ConstantR1<float>({2.f, 4.2f});
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto activation_deviations = builder.Sub(input_activations, set_means,
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           /*broadcast_dimensions=*/{1});
1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Computation add = CreateScalarAddComputation(F32, &builder);
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto dev_squares = builder.SquareF32(activation_deviations);
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum_of_squares = builder.Reduce(
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3});
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> expected = {18, 0.06};
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, VarianceToStddev) {
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "variance_to_stddev");
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto variance = builder.ConstantR1<float>({6.f, .02f});
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sqrt = builder.SqrtF32(variance);
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> expected = {2.44948974f, 0.14142136f};
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Compare against a forward batch normalization example in the NN spec
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// reference.
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SpecComparisonForward) {
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "batch_normalize_per_spec");
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_activations =
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.CheckShape(builder.ConstantLiteral(input_literal_),
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                         ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto gamma = builder.ConstantR1<float>({1.0, 1.0});
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto beta = builder.ConstantR1<float>({0.0, 0.0});
1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Computation add = CreateScalarAddComputation(F32, &builder);
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Reduce all dimensions except dimension 1.
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum = builder.CheckShape(
1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     /*dimensions_to_reduce=*/{0, 2, 3}),
1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      TwoElementVectorF32);
1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) /
1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                         ShapeUtil::ElementsIn(*sum_shape));
1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto set_means = builder.Div(sum, count);
1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const float kEpsilon = 1e-9f;
1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto epsilon = builder.ConstantR0<float>(kEpsilon);
1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon});
1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto activation_deviations = builder.Sub(input_activations, set_means,
1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           /*broadcast_dimensions=*/{1});
1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto dev_squares = builder.SquareF32(activation_deviations);
1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum_of_squares = builder.CheckShape(
1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     /*dimensions_to_reduce=*/{0, 2, 3}),
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      TwoElementVectorF32);
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto variance = builder.Div(sum_of_squares, count);
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto standard_deviation = builder.SqrtF32(variance);
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto standard_deviation_above_epsilon = builder.CheckShape(
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto gt_eps = builder.Select(standard_deviation_above_epsilon,
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               standard_deviation, epsilon2);
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto normalization_factors = builder.ReciprocalF32(gt_eps);
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto normalized_input_activations =
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Mul(activation_deviations, normalization_factors,
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  /*broadcast_dimensions=*/{1});
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  /* auto output_activations = */ builder.Add(
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Mul(normalized_input_activations, gamma,
1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  /*broadcast_dimensions=*/{1}),
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      beta, /*broadcast_dimensions=*/{1});
1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> expected(kSamples, kZ, kY, kX);
1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> pz({
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {0.f, -.1f / std::sqrt(.02f)},
1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  });
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected.FillWithPZ(pz);
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2021464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerstruct BatchNormTestParam {
2031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<int64> bounds;
2041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  int64 feature_index;
2051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  float random_value_mean;
2061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  float random_value_var;
207913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar
208913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar  friend ::std::ostream& operator<<(::std::ostream& os,
209913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar                                    const BatchNormTestParam& p) {
210913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, ";
211913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "feature_index=" << p.feature_index << ", ";
212913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "random_value_mean=" << p.random_value_mean << ", ";
213913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "random_value_var=" << p.random_value_var;
214913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    return os;
215913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar  }
2161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower};
2171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower// Tests to test the fused operation of BatchNorm.
2191464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerclass BatchNormTest : public ClientLibraryTestBase,
2201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      public ::testing::WithParamInterface<BatchNormTestParam> {
2211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower};
2221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
223f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedTests) {
2241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  float epsilon = 0.001;
2251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
2261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const std::vector<int64>& bounds = GetParam().bounds;
2271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
2281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  input_array.FillRandom(GetParam().random_value_var,
2291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         GetParam().random_value_mean);
2301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int64 feature_index = GetParam().feature_index;
2321464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int64 num_elements_per_feature =
2331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      Product(bounds) / bounds[feature_index];
2341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int64 feature_bound = bounds[feature_index];
2351464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> offset(feature_bound, 1);
2361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> scale(feature_bound, 2);
2371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto input_squared =
2391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
2401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<int64> reduce_dims;
241030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
2421464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    if (i != feature_index) {
2431464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      reduce_dims.push_back(i);
2441464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    }
2451464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2461464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2471464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto sum =
2481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
2491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
2501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto sum_squared =
2521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
2531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
2541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> mean(feature_bound);
2561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
2581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    mean[i] = sum[i] / num_elements_per_feature;
2591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> mean_square(feature_bound);
2621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
2631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    mean_square[i] = mean[i] * mean[i];
2641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> square_mean(feature_bound);
2671464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
2681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    square_mean[i] = sum_squared[i] / num_elements_per_feature;
2691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> var(feature_bound);
2721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
2731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    var[i] = square_mean[i] - mean_square[i];
2741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
276f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  Array4D<float> mean4D =
2771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
278f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
279f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
280f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto offset4D =
2811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
2821464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
283f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
284f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower                                                scale4D, offset4D, epsilon);
2851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2861464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);
2871464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset_literal = Literal::CreateR1<float>(offset);
2891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale_literal = Literal::CreateR1<float>(scale);
2901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
2911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto input_activations =
2931464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      builder.Parameter(0, input_literal->shape(), "input");
2941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale_activations =
2951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      builder.Parameter(1, scale_literal->shape(), "offset");
2961464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset_activations =
2971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      builder.Parameter(2, offset_literal->shape(), "scale");
2981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple({expected_normalized.get(),
3001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                       Literal::CreateR1<float>(mean).get(),
3011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                       Literal::CreateR1<float>(var).get()});
3021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::unique_ptr<GlobalData> input_data =
3041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
3051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::unique_ptr<GlobalData> scale_data =
3061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
3071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::unique_ptr<GlobalData> offset_data =
3081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
3091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  builder.BatchNormTraining(input_activations, scale_activations,
3111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                            offset_activations, epsilon, feature_index);
3121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(
3141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      &builder, expected,
3151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {input_data.get(), scale_data.get(), offset_data.get()},
3161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ErrorSpec(0.01, 1));
3171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
3181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3197359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
3207359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  float epsilon = 0.001;
3217359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
3227359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const std::vector<int64>& bounds = GetParam().bounds;
3237359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
3247359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  input_array.FillRandom(GetParam().random_value_var,
3257359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                         GetParam().random_value_mean);
3267359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3277359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const int64 feature_index = GetParam().feature_index;
3287359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const int64 num_elements_per_feature =
3297359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      Product(bounds) / bounds[feature_index];
3307359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const int64 feature_bound = bounds[feature_index];
3317359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> offset(feature_bound, 1);
3327359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> scale(feature_bound, 2);
3337359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3347359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto input_squared =
3357359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
3367359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<int64> reduce_dims;
3377359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
3387359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    if (i != feature_index) {
3397359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      reduce_dims.push_back(i);
3407359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    }
3417359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3427359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3437359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto sum =
3447359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
3457359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
3467359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3477359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto sum_squared =
3487359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
3497359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
3507359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3517359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> mean(feature_bound);
3527359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3537359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
3547359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    mean[i] = sum[i] / num_elements_per_feature;
3557359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3567359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3577359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> mean_square(feature_bound);
3587359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
3597359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    mean_square[i] = mean[i] * mean[i];
3607359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3617359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3627359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> square_mean(feature_bound);
3637359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
3647359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    square_mean[i] = sum_squared[i] / num_elements_per_feature;
3657359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3667359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3677359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> var(feature_bound);
3687359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
3697359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    var[i] = square_mean[i] - mean_square[i];
3707359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3717359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3727359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  Array4D<float> mean4D =
3737359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
3747359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
3757359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
3767359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto offset4D =
3777359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
3787359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3797359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
3807359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                                                scale4D, offset4D, epsilon);
3817359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3827359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto offset_literal = Literal::CreateR1<float>(offset);
3837359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto scale_literal = Literal::CreateR1<float>(scale);
3847359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto mean_literal = Literal::CreateR1<float>(mean);
3857359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto var_literal = Literal::CreateR1<float>(var);
3867359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
3877359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3887359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto input_activations =
3897359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(0, input_literal->shape(), "input");
3907359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto scale_activations =
3917359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(1, scale_literal->shape(), "offset");
3927359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto offset_activations =
3937359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(2, offset_literal->shape(), "scale");
3947359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
3957359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto variance_activations =
3967359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(4, var_literal->shape(), "variance");
3977359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3987359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  Array4D<float> expected = normalized;
3997359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
4007359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> input_data =
4017359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
4027359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> scale_data =
4037359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
4047359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> offset_data =
4057359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
4067359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> mean_data =
4077359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
4087359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> variance_data =
4097359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
4107359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
4117359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  builder.BatchNormInference(input_activations, scale_activations,
4127359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                             offset_activations, mean_activations,
4137359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                             variance_activations, epsilon, feature_index);
4147359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
4157359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  ComputeAndCompareR4<float>(
4167359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      &builder, expected,
4177359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
4187359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower       variance_data.get()},
4197359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ErrorSpec(0.01, 1));
4207359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower}
4217359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
422f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedGradTests) {
423ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  float epsilon = 0.001;
424ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
425ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const std::vector<int64>& bounds = GetParam().bounds;
426ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
427ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  input_array.FillRandom(GetParam().random_value_var,
428ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                         GetParam().random_value_mean);
429ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
430ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
431ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  grad_output_array.FillRandom(GetParam().random_value_var,
432ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                               GetParam().random_value_mean);
433ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
434ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int64 feature_index = GetParam().feature_index;
435ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int64 num_elements_per_feature =
436ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      Product(bounds) / bounds[feature_index];
437ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int64 feature_bound = bounds[feature_index];
438ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> scale(feature_bound, 2);
439ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
440ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto input_squared =
441ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
442ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<int64> reduce_dims;
443030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
444ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    if (i != feature_index) {
445ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      reduce_dims.push_back(i);
446ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    }
447ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
448ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
449ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto sum =
450ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
451ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                  [](float a, float b) { return a + b; });
452ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
453ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto sum_squared =
454ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
455ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                  [](float a, float b) { return a + b; });
456ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
457ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> mean(feature_bound);
458ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
459ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
460ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    mean[i] = sum[i] / num_elements_per_feature;
461ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
462ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
463ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> mean_square(feature_bound);
464ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
465ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    mean_square[i] = mean[i] * mean[i];
466ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
467ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
468ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> square_mean(feature_bound);
469ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
470ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    square_mean[i] = sum_squared[i] / num_elements_per_feature;
471ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
472ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
473ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> var(feature_bound);
474ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
475ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    var[i] = square_mean[i] - mean_square[i];
476ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
477ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
478f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  Array4D<float> mean4D =
479ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
480f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
481f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
482ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
483ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
484030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      var4D, [epsilon](float a) { return a + epsilon; });
485030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
486030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
487030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); });
488ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
489ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output_times_var =
490ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
491ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 [](float a, float b) { return a * b; });
492ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
493ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto activation_shifted = *ReferenceUtil::MapArray4D(
494f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower      input_array, mean4D, [](float a, float b) { return a - b; });
495ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
496030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto activation_shifted_times_grad_output =
497030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
498ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 [](float a, float b) { return a * b; });
499ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
500030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
501030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
502030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      [](float a, float b) { return a * b; });
503030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
504ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
505ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
506ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      [](float a, float b) { return a + b; });
507ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
508ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_offset =
509ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
510ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                  [](float a, float b) { return a + b; });
511ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
512030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
513030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });
514030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
515030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I1 = *ReferenceUtil::MapArray4D(
516030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_output_array, [&](float a) { return num_elements_per_feature * a; });
517030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
518030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);
519030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
520030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  // I3 = sum(output_grad * (activation - mean(activation)))
521030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
522030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
523030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                  /*init=*/0.0f, reduce_dims,
524030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                  [](float a, float b) { return a + b; }),
525030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      bounds, feature_index);
526030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
527030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  // I4 = (activation - mean(activation)) *
528030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  //   sum(output_grad * (activation - mean(activation)))
529030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
530030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                       [](float a, float b) { return a * b; });
531030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
532030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  // I5 = (activation - mean(activation)) *
533030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  //   sum(output_grad * (activation - mean(activation))) / (variance +
534030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  //   epsilon))
535030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
536030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                       [](float a, float b) { return a / b; });
537030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
538030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto grad_activation = *ReferenceUtil::MapArray4D(
539030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      I1, I2, [](float a, float b) { return a - b; });
540030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
541030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  grad_activation = *ReferenceUtil::MapArray4D(
542030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_activation, I5, [](float a, float b) { return a - b; });
543030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
544030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  grad_activation = *ReferenceUtil::MapArray4D(
545030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_activation, scale4D, [](float a, float b) { return a * b; });
546030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
547030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  grad_activation = *ReferenceUtil::MapArray4D(
548030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_activation, rsqrt_var_add_epsilon,
549030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      [=](float a, float b) { return a * b / num_elements_per_feature; });
550030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
551ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto expected_grad_activation =
552ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      Literal::CreateR4FromArray4D<float>(grad_activation);
553ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
554ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
555ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto scale_literal = Literal::CreateR1<float>(scale);
556ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto mean_literal = Literal::CreateR1<float>(mean);
557ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var_literal = Literal::CreateR1<float>(var);
558ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output_literal =
559ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      Literal::CreateR4FromArray4D<float>(grad_output_array);
560ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
561ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto input_parameter = builder.Parameter(0, input_literal->shape(), "input");
562ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale");
563ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean");
564ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance");
565ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output_parameter =
566ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      builder.Parameter(4, grad_output_literal->shape(), "grad_output");
567ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
568ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> input_data =
569ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
570ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> scale_data =
571ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
572ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> mean_data =
573ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
574ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> var_data =
575ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
576ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> grad_output_data =
577ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
578ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
579ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto t = builder.BatchNormGrad(input_parameter, scale_parameter,
580ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 mean_parameter, var_parameter,
581ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 grad_output_parameter, epsilon, feature_index);
582ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
583ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto expected =
584ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      *Literal::MakeTuple({expected_grad_activation.get(),
585ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                           Literal::CreateR1<float>(grad_scale).get(),
586ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                           Literal::CreateR1<float>(grad_offset).get()});
587ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
588ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected,
589ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                         {input_data.get(), scale_data.get(), mean_data.get(),
590ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                          var_data.get(), grad_output_data.get()},
591ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                         ErrorSpec(0.01, 1));
592ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower}
593ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
5941464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerINSTANTIATE_TEST_CASE_P(
5951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    BatchNormTest_Instantiation, BatchNormTest,
5961464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    ::testing::Values(BatchNormTestParam{{2, 2, 2, 2}, 0, 100.2f, 200.0f},
5971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{2, 2, 2, 2}, 3, 300.f, 400.0f},
5981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
5991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{1, 10, 1, 1}, 0, 10.1f, 20.1f},
6001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{10, 10, 10, 10}, 1, 3.14f, 314.15f},
6011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{10, 10, 10, 10}, 2, 666.6f, 777.7f},
6021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{10, 10, 10, 10}, 1, -666.6f, 777.7f},
6031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{10, 10, 10, 10}, 2, 0.f, 777.7f},
6044ab3342eb5fdae864e0a32a0e460a27527d79997A. Unique TensorFlower                      BatchNormTestParam{{1, 1, 10, 130}, 2, 0.f, 777.7f},
605ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                      BatchNormTestParam{{1, 1, 130, 11}, 2, 0.f, 777.7f},
6061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{1, 1, 10, 1}, 3, 888.8f, 9.9f},
6071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{24, 129, 1, 2}, 2, 10000, 10000},
6091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{24, 129, 1, 2}, 3, 10000, 10000},
6101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      // Feature on low dimension to trigger relayout, test
6121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      // internal logical to physical dimension calculation
6131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      // is correct after relayout.
6141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100}));
6151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
616f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BasicTraining) {
6171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int kFeatureIndex = 3;
6181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
6191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto operand = builder.ConstantR4FromArray4D<float>(
6211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});
6221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});
6241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});
6261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto tuple = builder.BatchNormTraining(operand, scale, offset,
6281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                         /*epsilon=*/0.001, kFeatureIndex);
6291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple(
6311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
6321464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                 {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
6331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower           .get(),
6341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>({4, 5}).get(),
6351464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>({5, 5}).get()});
6361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
6381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
6391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
640f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) {
6411464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int kFeatureIndex = 2;
6421464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
6431464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6441464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto operand = builder.ConstantR4FromArray4D<float>(
6451464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
6461464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6471464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});
6481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});
6501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto tuple = builder.BatchNormTraining(operand, scale, offset,
6521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                         /*epsilon=*/0.001, kFeatureIndex);
6531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple(
6551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
6561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                 {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
6571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower           .get(),
6581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>({4, 5}).get(),
6591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>({5, 5}).get()});
6601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
6621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
6631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6644c15e948e0e4a17329b2363a5f8f3d4b4178ef7bA. Unique TensorFlowerXLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
6651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  // Use 0 dimension as feature, tests layout analyzer.
6661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int kFeatureIndex = 0;
6671464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
6681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h0;
6701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
6711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                          /*parameter_number=*/0, "operand",
6721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                          &builder, &h0);
6731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h1;
6741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale =
6751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
6761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                               /*parameter_number=*/1, "scale", &builder, &h1);
6771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h2;
6781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset =
6791464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
6801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                               /*parameter_number=*/2, "offset", &builder, &h2);
6811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6821464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto tuple = builder.BatchNormTraining(h0, h1, h2,
6831464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                         /*epsilon=*/1, kFeatureIndex);
6841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple(
6861464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
6871464b9930de871fd11870941963253670f737c23A. Unique TensorFlower           .get(),
6881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
6891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
6901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected,
6921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         {operand.get(), scale.get(), offset.get()},
6931464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         ErrorSpec(0.1));
6941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
6951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
696f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, LargeEpsilonTest) {
6971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  // Test the correctness of choosing a large epsilon value.
6981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int kFeatureIndex = 2;
6991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
7001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
7011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h0;
7021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}},
7031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                          /*parameter_number=*/0, "operand",
7041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                          &builder, &h0);
7051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h1;
7061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale =
7071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      CreateR1Parameter<float>(std::vector<float>(1, 1.0f),
7081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                               /*parameter_number=*/1, "scale", &builder, &h1);
7091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h2;
7101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset =
7111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      CreateR1Parameter<float>(std::vector<float>(1, 0.0f),
7121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                               /*parameter_number=*/2, "offset", &builder, &h2);
7131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
7141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  // var = 125, mean = 15, epsilon = -100
7151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto tuple = builder.BatchNormTraining(h0, h1, h2,
7161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                         /*epsilon=*/-100, kFeatureIndex);
7171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
7181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple(
7191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
7201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower           .get(),
7211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
7221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
7231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
7241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected,
7251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         {operand.get(), scale.get(), offset.get()},
7261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         ErrorSpec(0.1));
7271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
7281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
729f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BatchNormGradBasic) {
730ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int kFeatureIndex = 2;
731ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
732ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
733ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto operand =
734ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f));
735ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
736ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto scale = builder.ConstantR1<float>({1.0f, 1.0f});
737ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
738ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto mean = builder.ConstantR1<float>({0.0f, 0.0f});
739ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
740ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var = builder.ConstantR1<float>({1.0f, 1.0f});
741ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
742ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output = builder.ConstantR4FromArray4D<float>(
743ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
744ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
745ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  builder.BatchNormGrad(operand, scale, mean, var, grad_output,
746ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                        /*epsilon=*/0.0, kFeatureIndex);
747ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
748ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto expected = *Literal::MakeTuple(
749030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
750030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                 {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
751ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower           .get(),
752ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower       Literal::CreateR1<float>({0, 0}).get(),
753ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower       Literal::CreateR1<float>({16, 20}).get()});
754ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
755ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
756ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower}
757ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
7581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
7591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
760