batch_normalization_test.cc revision 3d05024db5fc0c39ee3f5626639bd611f44ac03c
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <cmath>
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory>
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h"
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array4d.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/local_client.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h"
271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/reference_util.h"
281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_computation.h"
291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h"
301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_module.h"
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h"
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/statusor.h"
331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test.h"
341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test_helpers.h"
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h"
381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_macros.h"
391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_utils.h"
401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/util.h"
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
423d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower#include "tensorflow/core/lib/math/math_util.h"
43913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar#include "tensorflow/core/lib/strings/str_util.h"
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/test.h"
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h"
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass BatchNormalizationTest : public ClientLibraryTestBase {
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected:
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    Array2D<float> pz({
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        // z0 z1
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        {-1.0f, 4.1f},  // p0
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        {2.0f, 4.1f},   // p1
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        {5.0f, 4.4f},   // p2
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    });
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    input_array_.FillWithPZ(pz);
6146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower    input_literal_ = *Literal::CreateR4FromArray4D(input_array_);
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kSamples, input_array_.planes());
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kZ, input_array_.depth());
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kY, input_array_.height());
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    CHECK_EQ(kY, input_array_.width());
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kSamples = 3;
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kX = 1;
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kY = 1;
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static constexpr int64 kZ = 2;
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> input_array_;
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Literal input_literal_;
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const ErrorSpec error_spec_{0.001, 0.001};
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SubtractInZ) {
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "subtract_in_z_one_sample");
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto x = builder.ConstantLiteral(input_literal_);
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto y = builder.ConstantR1<float>({3.14, 4.25});
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.Sub(x, y, /*broadcast_dimensions=*/{1});
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> expected(kSamples, kZ, kY, kX);
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> pz({
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  });
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected.FillWithPZ(pz);
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SquareTesseractElementwise) {
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "square_tesseract_elementwise");
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto x = builder.ConstantLiteral(input_literal_);
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.SquareF32(x);
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
993d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower  using tensorflow::MathUtil;
1003d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> expected(kSamples, kZ, kY, kX);
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> expected_pz({
1033d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower      {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)},
1043d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower      {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)},
1053d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower      {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)},
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  });
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected.FillWithPZ(expected_pz);
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SumToZ) {
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "sum_to_z");
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_activations = builder.ConstantLiteral(input_literal_);
1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Computation add = CreateScalarAddComputation(F32, &builder);
1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Reduce all but the Z dimension.
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                 {0, 2, 3});
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> expected = {6, 12.6};
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SquareAndReduce) {
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "square_and_reduce");
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_activations = builder.ConstantLiteral(input_literal_);
1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto set_means = builder.ConstantR1<float>({2.f, 4.2f});
1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto activation_deviations = builder.Sub(input_activations, set_means,
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           /*broadcast_dimensions=*/{1});
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Computation add = CreateScalarAddComputation(F32, &builder);
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto dev_squares = builder.SquareF32(activation_deviations);
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum_of_squares = builder.Reduce(
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3});
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> expected = {18, 0.06};
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, VarianceToStddev) {
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "variance_to_stddev");
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto variance = builder.ConstantR1<float>({6.f, .02f});
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sqrt = builder.SqrtF32(variance);
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  std::vector<float> expected = {2.44948974f, 0.14142136f};
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Compare against a forward batch normalization example in the NN spec
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// reference.
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SpecComparisonForward) {
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputationBuilder builder(client_, "batch_normalize_per_spec");
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_activations =
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.CheckShape(builder.ConstantLiteral(input_literal_),
1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                         ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto gamma = builder.ConstantR1<float>({1.0, 1.0});
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto beta = builder.ConstantR1<float>({0.0, 0.0});
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Computation add = CreateScalarAddComputation(F32, &builder);
1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Reduce all dimensions except dimension 1.
1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum = builder.CheckShape(
1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     /*dimensions_to_reduce=*/{0, 2, 3}),
1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      TwoElementVectorF32);
1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) /
1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                         ShapeUtil::ElementsIn(*sum_shape));
1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto set_means = builder.Div(sum, count);
1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  const float kEpsilon = 1e-9f;
1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto epsilon = builder.ConstantR0<float>(kEpsilon);
1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon});
1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto activation_deviations = builder.Sub(input_activations, set_means,
1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           /*broadcast_dimensions=*/{1});
1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto dev_squares = builder.SquareF32(activation_deviations);
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto sum_of_squares = builder.CheckShape(
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                     /*dimensions_to_reduce=*/{0, 2, 3}),
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      TwoElementVectorF32);
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto variance = builder.Div(sum_of_squares, count);
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto standard_deviation = builder.SqrtF32(variance);
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto standard_deviation_above_epsilon = builder.CheckShape(
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto gt_eps = builder.Select(standard_deviation_above_epsilon,
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                               standard_deviation, epsilon2);
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto normalization_factors = builder.ReciprocalF32(gt_eps);
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto normalized_input_activations =
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Mul(activation_deviations, normalization_factors,
1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  /*broadcast_dimensions=*/{1});
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  /* auto output_activations = */ builder.Add(
1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      builder.Mul(normalized_input_activations, gamma,
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                  /*broadcast_dimensions=*/{1}),
1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      beta, /*broadcast_dimensions=*/{1});
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array4D<float> expected(kSamples, kZ, kY, kX);
1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Array2D<float> pz({
1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {0.f, -.1f / std::sqrt(.02f)},
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  });
2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  expected.FillWithPZ(pz);
2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2051464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerstruct BatchNormTestParam {
2061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<int64> bounds;
2071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  int64 feature_index;
2081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  float random_value_mean;
2091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  float random_value_var;
210913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar
211913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar  friend ::std::ostream& operator<<(::std::ostream& os,
212913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar                                    const BatchNormTestParam& p) {
213913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, ";
214913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "feature_index=" << p.feature_index << ", ";
215913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "random_value_mean=" << p.random_value_mean << ", ";
216913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    os << "random_value_var=" << p.random_value_var;
217913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar    return os;
218913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar  }
2191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower};
2201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower// Tests to test the fused operation of BatchNorm.
2221464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerclass BatchNormTest : public ClientLibraryTestBase,
2231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                      public ::testing::WithParamInterface<BatchNormTestParam> {
2241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower};
2251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
226f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedTests) {
2271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  float epsilon = 0.001;
2281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
2291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const std::vector<int64>& bounds = GetParam().bounds;
2301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
2311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  input_array.FillRandom(GetParam().random_value_var,
2321464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         GetParam().random_value_mean);
2331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int64 feature_index = GetParam().feature_index;
2351464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int64 num_elements_per_feature =
2361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      Product(bounds) / bounds[feature_index];
2371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int64 feature_bound = bounds[feature_index];
2381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> offset(feature_bound, 1);
2391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> scale(feature_bound, 2);
2401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2411464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto input_squared =
2421464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
2431464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<int64> reduce_dims;
244030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
2451464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    if (i != feature_index) {
2461464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      reduce_dims.push_back(i);
2471464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    }
2481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto sum =
2511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
2521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
2531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto sum_squared =
2551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
2561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
2571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> mean(feature_bound);
2591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
2611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    mean[i] = sum[i] / num_elements_per_feature;
2621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> mean_square(feature_bound);
2651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
2661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    mean_square[i] = mean[i] * mean[i];
2671464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> square_mean(feature_bound);
2701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
2711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    square_mean[i] = sum_squared[i] / num_elements_per_feature;
2721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::vector<float> var(feature_bound);
2751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
2761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower    var[i] = square_mean[i] - mean_square[i];
2771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  }
2781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
279f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  Array4D<float> mean4D =
2801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
281f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
282f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
283f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto offset4D =
2841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
2851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
286f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
287f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower                                                scale4D, offset4D, epsilon);
2881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);
2901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset_literal = Literal::CreateR1<float>(offset);
2921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale_literal = Literal::CreateR1<float>(scale);
2931464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
2941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
2951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto input_activations =
2961464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      builder.Parameter(0, input_literal->shape(), "input");
2971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale_activations =
2981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      builder.Parameter(1, scale_literal->shape(), "offset");
2991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset_activations =
3001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      builder.Parameter(2, offset_literal->shape(), "scale");
3011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple({expected_normalized.get(),
3031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                       Literal::CreateR1<float>(mean).get(),
3041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                       Literal::CreateR1<float>(var).get()});
3051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::unique_ptr<GlobalData> input_data =
3071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
3081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::unique_ptr<GlobalData> scale_data =
3091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
3101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  std::unique_ptr<GlobalData> offset_data =
3111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
3121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  builder.BatchNormTraining(input_activations, scale_activations,
3141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                            offset_activations, epsilon, feature_index);
3151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(
3171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      &builder, expected,
3181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {input_data.get(), scale_data.get(), offset_data.get()},
3191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      ErrorSpec(0.01, 1));
3201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
3211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
3227359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
3237359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  float epsilon = 0.001;
3247359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
3257359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const std::vector<int64>& bounds = GetParam().bounds;
3267359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
3277359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  input_array.FillRandom(GetParam().random_value_var,
3287359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                         GetParam().random_value_mean);
3297359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3307359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const int64 feature_index = GetParam().feature_index;
3317359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const int64 num_elements_per_feature =
3327359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      Product(bounds) / bounds[feature_index];
3337359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  const int64 feature_bound = bounds[feature_index];
3347359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> offset(feature_bound, 1);
3357359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> scale(feature_bound, 2);
3367359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3377359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto input_squared =
3387359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
3397359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<int64> reduce_dims;
3407359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
3417359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    if (i != feature_index) {
3427359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      reduce_dims.push_back(i);
3437359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    }
3447359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3457359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3467359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto sum =
3477359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
3487359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
3497359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3507359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto sum_squared =
3517359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
3527359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                                  [](float a, float b) { return a + b; });
3537359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3547359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> mean(feature_bound);
3557359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3567359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
3577359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    mean[i] = sum[i] / num_elements_per_feature;
3587359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3597359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3607359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> mean_square(feature_bound);
3617359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
3627359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    mean_square[i] = mean[i] * mean[i];
3637359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3647359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3657359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> square_mean(feature_bound);
3667359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
3677359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    square_mean[i] = sum_squared[i] / num_elements_per_feature;
3687359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3697359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3707359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::vector<float> var(feature_bound);
3717359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
3727359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower    var[i] = square_mean[i] - mean_square[i];
3737359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  }
3747359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3757359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  Array4D<float> mean4D =
3767359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
3777359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
3787359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
3797359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto offset4D =
3807359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
3817359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3827359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
3837359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                                                scale4D, offset4D, epsilon);
3847359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3857359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto offset_literal = Literal::CreateR1<float>(offset);
3867359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto scale_literal = Literal::CreateR1<float>(scale);
3877359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto mean_literal = Literal::CreateR1<float>(mean);
3887359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto var_literal = Literal::CreateR1<float>(var);
3897359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
3907359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
3917359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto input_activations =
3927359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(0, input_literal->shape(), "input");
3937359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto scale_activations =
3947359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(1, scale_literal->shape(), "offset");
3957359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto offset_activations =
3967359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(2, offset_literal->shape(), "scale");
3977359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
3987359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  auto variance_activations =
3997359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      builder.Parameter(4, var_literal->shape(), "variance");
4007359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
4017359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  Array4D<float> expected = normalized;
4027359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
4037359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> input_data =
4047359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
4057359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> scale_data =
4067359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
4077359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> offset_data =
4087359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
4097359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> mean_data =
4107359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
4117359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  std::unique_ptr<GlobalData> variance_data =
4127359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
4137359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
4147359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  builder.BatchNormInference(input_activations, scale_activations,
4157359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                             offset_activations, mean_activations,
4167359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower                             variance_activations, epsilon, feature_index);
4177359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
4187359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower  ComputeAndCompareR4<float>(
4197359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      &builder, expected,
4207359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
4217359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower       variance_data.get()},
4227359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower      ErrorSpec(0.01, 1));
4237359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower}
4247359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower
425f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedGradTests) {
426ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  float epsilon = 0.001;
427ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
428ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const std::vector<int64>& bounds = GetParam().bounds;
429ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
430ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  input_array.FillRandom(GetParam().random_value_var,
431ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                         GetParam().random_value_mean);
432ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
433ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
434ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  grad_output_array.FillRandom(GetParam().random_value_var,
435ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                               GetParam().random_value_mean);
436ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
437ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int64 feature_index = GetParam().feature_index;
438ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int64 num_elements_per_feature =
439ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      Product(bounds) / bounds[feature_index];
440ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int64 feature_bound = bounds[feature_index];
441ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> scale(feature_bound, 2);
442ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
443ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto input_squared =
444ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
445ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<int64> reduce_dims;
446030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
447ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    if (i != feature_index) {
448ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      reduce_dims.push_back(i);
449ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    }
450ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
451ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
452ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto sum =
453ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
454ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                  [](float a, float b) { return a + b; });
455ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
456ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto sum_squared =
457ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
458ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                  [](float a, float b) { return a + b; });
459ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
460ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> mean(feature_bound);
461ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
462ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
463ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    mean[i] = sum[i] / num_elements_per_feature;
464ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
465ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
466ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> mean_square(feature_bound);
467ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
468ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    mean_square[i] = mean[i] * mean[i];
469ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
470ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
471ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> square_mean(feature_bound);
472ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
473ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    square_mean[i] = sum_squared[i] / num_elements_per_feature;
474ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
475ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
476ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::vector<float> var(feature_bound);
477ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  for (int64 i = 0; i < feature_bound; ++i) {
478ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower    var[i] = square_mean[i] - mean_square[i];
479ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  }
480ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
481f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  Array4D<float> mean4D =
482ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
483f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
484f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
485ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
486ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
487030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      var4D, [epsilon](float a) { return a + epsilon; });
488030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
489030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
490030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); });
491ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
492ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output_times_var =
493ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
494ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 [](float a, float b) { return a * b; });
495ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
496ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto activation_shifted = *ReferenceUtil::MapArray4D(
497f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower      input_array, mean4D, [](float a, float b) { return a - b; });
498ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
499030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto activation_shifted_times_grad_output =
500030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
501ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 [](float a, float b) { return a * b; });
502ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
503030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
504030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
505030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      [](float a, float b) { return a * b; });
506030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
507ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
508ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
509ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      [](float a, float b) { return a + b; });
510ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
511ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_offset =
512ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
513ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                  [](float a, float b) { return a + b; });
514ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
515030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
516030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });
517030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
518030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I1 = *ReferenceUtil::MapArray4D(
519030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_output_array, [&](float a) { return num_elements_per_feature * a; });
520030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
521030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);
522030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
523030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  // I3 = sum(output_grad * (activation - mean(activation)))
524030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
525030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
526030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                  /*init=*/0.0f, reduce_dims,
527030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                  [](float a, float b) { return a + b; }),
528030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      bounds, feature_index);
529030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
530030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  // I4 = (activation - mean(activation)) *
531030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  //   sum(output_grad * (activation - mean(activation)))
532030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
533030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                       [](float a, float b) { return a * b; });
534030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
535030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  // I5 = (activation - mean(activation)) *
536030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  //   sum(output_grad * (activation - mean(activation))) / (variance +
537030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  //   epsilon))
538030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
539030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                       [](float a, float b) { return a / b; });
540030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
541030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  auto grad_activation = *ReferenceUtil::MapArray4D(
542030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      I1, I2, [](float a, float b) { return a - b; });
543030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
544030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  grad_activation = *ReferenceUtil::MapArray4D(
545030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_activation, I5, [](float a, float b) { return a - b; });
546030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
547030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  grad_activation = *ReferenceUtil::MapArray4D(
548030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_activation, scale4D, [](float a, float b) { return a * b; });
549030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
550030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower  grad_activation = *ReferenceUtil::MapArray4D(
551030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      grad_activation, rsqrt_var_add_epsilon,
552030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      [=](float a, float b) { return a * b / num_elements_per_feature; });
553030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower
554ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto expected_grad_activation =
555ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      Literal::CreateR4FromArray4D<float>(grad_activation);
556ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
557ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
558ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto scale_literal = Literal::CreateR1<float>(scale);
559ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto mean_literal = Literal::CreateR1<float>(mean);
560ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var_literal = Literal::CreateR1<float>(var);
561ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output_literal =
562ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      Literal::CreateR4FromArray4D<float>(grad_output_array);
563ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
564ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto input_parameter = builder.Parameter(0, input_literal->shape(), "input");
565ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale");
566ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean");
567ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance");
568ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output_parameter =
569ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      builder.Parameter(4, grad_output_literal->shape(), "grad_output");
570ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
571ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> input_data =
572ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
573ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> scale_data =
574ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
575ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> mean_data =
576ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
577ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> var_data =
578ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
579ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  std::unique_ptr<GlobalData> grad_output_data =
580ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
581ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
582ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto t = builder.BatchNormGrad(input_parameter, scale_parameter,
583ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 mean_parameter, var_parameter,
584ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                                 grad_output_parameter, epsilon, feature_index);
585ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
586ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto expected =
587ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      *Literal::MakeTuple({expected_grad_activation.get(),
588ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                           Literal::CreateR1<float>(grad_scale).get(),
589ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                           Literal::CreateR1<float>(grad_offset).get()});
590ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
591ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected,
592ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                         {input_data.get(), scale_data.get(), mean_data.get(),
593ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                          var_data.get(), grad_output_data.get()},
594ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                         ErrorSpec(0.01, 1));
595ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower}
596ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
// Parameter sets for every XLA_TEST_P case of the BatchNormTest fixture.
// Each entry is BatchNormTestParam{bounds, feature_index, <float>, <float>}:
// `bounds` is the 4-D operand shape (the tests build Array4D from
// bounds[0..3]) and `feature_index` selects the feature dimension. The two
// trailing floats presumably are random_value_mean and random_value_var, which
// the tests pass to FillRandom; the struct is declared outside this chunk, so
// confirm the field order there. Keep the entries in order: gtest names the
// instantiated tests by index.
INSTANTIATE_TEST_CASE_P(
    BatchNormTest_Instantiation, BatchNormTest,
    ::testing::Values(BatchNormTestParam{{2, 2, 2, 2}, 0, 100.2f, 200.0f},
                      BatchNormTestParam{{2, 2, 2, 2}, 3, 300.f, 400.0f},

                      BatchNormTestParam{{1, 10, 1, 1}, 0, 10.1f, 20.1f},
                      BatchNormTestParam{{10, 10, 10, 10}, 1, 3.14f, 314.15f},
                      BatchNormTestParam{{10, 10, 10, 10}, 2, 666.6f, 777.7f},
                      BatchNormTestParam{{10, 10, 10, 10}, 1, -666.6f, 777.7f},
                      BatchNormTestParam{{10, 10, 10, 10}, 2, 0.f, 777.7f},
                      BatchNormTestParam{{1, 1, 10, 130}, 2, 0.f, 777.7f},
                      BatchNormTestParam{{1, 1, 130, 11}, 2, 0.f, 777.7f},
                      BatchNormTestParam{{1, 1, 10, 1}, 3, 888.8f, 9.9f},

                      BatchNormTestParam{{24, 129, 1, 2}, 2, 10000, 10000},
                      BatchNormTestParam{{24, 129, 1, 2}, 3, 10000, 10000},

                      // Put the feature on a low dimension to trigger a
                      // relayout. This tests that the internal
                      // logical-to-physical dimension calculation is still
                      // correct after the relayout.
                      BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100}));
6181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
619f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BasicTraining) {
6201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int kFeatureIndex = 3;
6211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
6221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto operand = builder.ConstantR4FromArray4D<float>(
6241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});
6251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});
6271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});
6291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto tuple = builder.BatchNormTraining(operand, scale, offset,
6311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                         /*epsilon=*/0.001, kFeatureIndex);
6321464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple(
6341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
6351464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                 {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
6361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower           .get(),
6371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>({4, 5}).get(),
6381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>({5, 5}).get()});
6391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
6411464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
6421464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
643f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) {
6441464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int kFeatureIndex = 2;
6451464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
6461464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6471464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto operand = builder.ConstantR4FromArray4D<float>(
6481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
6491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});
6511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});
6531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto tuple = builder.BatchNormTraining(operand, scale, offset,
6551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                         /*epsilon=*/0.001, kFeatureIndex);
6561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple(
6581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
6591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                 {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
6601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower           .get(),
6611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>({4, 5}).get(),
6621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>({5, 5}).get()});
6631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
6651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
6661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6674c15e948e0e4a17329b2363a5f8f3d4b4178ef7bA. Unique TensorFlowerXLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
6681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  // Use 0 dimension as feature, tests layout analyzer.
6691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int kFeatureIndex = 0;
6701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
6711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h0;
6731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
6741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                          /*parameter_number=*/0, "operand",
6751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                          &builder, &h0);
6761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h1;
6771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale =
6781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
6791464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                               /*parameter_number=*/1, "scale", &builder, &h1);
6801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h2;
6811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset =
6821464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
6831464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                               /*parameter_number=*/2, "offset", &builder, &h2);
6841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto tuple = builder.BatchNormTraining(h0, h1, h2,
6861464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                         /*epsilon=*/1, kFeatureIndex);
6871464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple(
6891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
6901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower           .get(),
6911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
6921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
6931464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
6941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected,
6951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         {operand.get(), scale.get(), offset.get()},
6961464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         ErrorSpec(0.1));
6971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
6981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
699f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, LargeEpsilonTest) {
7001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  // Test the correctness of choosing a large epsilon value.
7011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  const int kFeatureIndex = 2;
7021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
7031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
7041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h0;
7051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}},
7061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                          /*parameter_number=*/0, "operand",
7071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                          &builder, &h0);
7081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h1;
7091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto scale =
7101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      CreateR1Parameter<float>(std::vector<float>(1, 1.0f),
7111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                               /*parameter_number=*/1, "scale", &builder, &h1);
7121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputationDataHandle h2;
7131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto offset =
7141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      CreateR1Parameter<float>(std::vector<float>(1, 0.0f),
7151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                               /*parameter_number=*/2, "offset", &builder, &h2);
7161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
7171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  // var = 125, mean = 15, epsilon = -100
7181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto tuple = builder.BatchNormTraining(h0, h1, h2,
7191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                                         /*epsilon=*/-100, kFeatureIndex);
7201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
7211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  auto expected = *Literal::MakeTuple(
7221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower      {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
7231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower           .get(),
7241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
7251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower       Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
7261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
7271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected,
7281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         {operand.get(), scale.get(), offset.get()},
7291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower                         ErrorSpec(0.1));
7301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}
7311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower
732f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BatchNormGradBasic) {
733ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  const int kFeatureIndex = 2;
734ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  ComputationBuilder builder(client_, TestName());
735ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
736ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto operand =
737ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f));
738ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
739ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto scale = builder.ConstantR1<float>({1.0f, 1.0f});
740ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
741ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto mean = builder.ConstantR1<float>({0.0f, 0.0f});
742ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
743ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto var = builder.ConstantR1<float>({1.0f, 1.0f});
744ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
745ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto grad_output = builder.ConstantR4FromArray4D<float>(
746ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
747ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
748ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  builder.BatchNormGrad(operand, scale, mean, var, grad_output,
749ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower                        /*epsilon=*/0.0, kFeatureIndex);
750ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
751ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  auto expected = *Literal::MakeTuple(
752030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower      {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
753030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower                                 {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
754ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower           .get(),
755ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower       Literal::CreateR1<float>({0, 0}).get(),
756ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower       Literal::CreateR1<float>({16, 20}).get()});
757ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
758ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
759ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower}
760ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower
7611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
7621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
763