batch_normalization_test.cc revision 913175c2bd38f6e97de399b29cfe1195bffbaa25
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <cmath> 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory> 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector> 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h" 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array4d.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/lib/arithmetic.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/local_client.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h" 271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/reference_util.h" 281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_computation.h" 291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h" 301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_module.h" 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h" 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/statusor.h" 331464b9930de871fd11870941963253670f737c23A. 
Unique TensorFlower#include "tensorflow/compiler/xla/test.h" 341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test_helpers.h" 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h" 381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_macros.h" 391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_utils.h" 401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/util.h" 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h" 42913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar#include "tensorflow/core/lib/strings/str_util.h" 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/test.h" 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h" 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass BatchNormalizationTest : public ClientLibraryTestBase { 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected: 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) { 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 
Array2D<float> pz({ 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // z0 z1 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {-1.0f, 4.1f}, // p0 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {2.0f, 4.1f}, // p1 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {5.0f, 4.4f}, // p2 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins input_array_.FillWithPZ(pz); 6046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower input_literal_ = *Literal::CreateR4FromArray4D(input_array_); 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kSamples, input_array_.planes()); 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kZ, input_array_.depth()); 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kY, input_array_.height()); 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kY, input_array_.width()); 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kSamples = 3; 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kX = 1; 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kY = 1; 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kZ = 2; 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> input_array_; 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Literal input_literal_; 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const ErrorSpec error_spec_{0.001, 0.001}; 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SubtractInZ) { 
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "subtract_in_z_one_sample"); 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto x = builder.ConstantLiteral(input_literal_); 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto y = builder.ConstantR1<float>({3.14, 4.25}); 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Sub(x, y, /*broadcast_dimensions=*/{1}); 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> expected(kSamples, kZ, kY, kX); 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> pz({ 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {-1.0f - 3.14f, 4.1f - 4.25f}, // p0 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {2.0f - 3.14f, 4.1f - 4.25f}, // p1 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {5.0f - 3.14f, 4.4f - 4.25f}, // p2 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected.FillWithPZ(pz); 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SquareTesseractElementwise) { 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "square_tesseract_elementwise"); 951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto x = builder.ConstantLiteral(input_literal_); 961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.SquareF32(x); 971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> expected(kSamples, kZ, kY, kX); 991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> expected_pz({ 
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {std::pow(-1.0f, 2.0f), std::pow(4.1f, 2.0f)}, 1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {std::pow(2.0f, 2.0f), std::pow(4.1f, 2.0f)}, 1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {std::pow(5.0f, 2.0f), std::pow(4.4f, 2.0f)}, 1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected.FillWithPZ(expected_pz); 1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SumToZ) { 1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "sum_to_z"); 1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_activations = builder.ConstantLiteral(input_literal_); 1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Computation add = CreateScalarAddComputation(F32, &builder); 1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Reduce all but the Z dimension. 
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, 1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {0, 2, 3}); 1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> expected = {6, 12.6}; 1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); 1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SquareAndReduce) { 1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "square_and_reduce"); 1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_activations = builder.ConstantLiteral(input_literal_); 1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto set_means = builder.ConstantR1<float>({2.f, 4.2f}); 1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto activation_deviations = builder.Sub(input_activations, set_means, 1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}); 1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Computation add = CreateScalarAddComputation(F32, &builder); 1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto dev_squares = builder.SquareF32(activation_deviations); 1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum_of_squares = builder.Reduce( 1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3}); 1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> expected = {18, 0.06}; 1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); 
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, VarianceToStddev) { 1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "variance_to_stddev"); 1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto variance = builder.ConstantR1<float>({6.f, .02f}); 1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sqrt = builder.SqrtF32(variance); 1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> expected = {2.44948974f, 0.14142136f}; 1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); 1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Compare against a forward batch normalization example in the NN spec 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// reference. 
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SpecComparisonForward) { 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "batch_normalize_per_spec"); 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_activations = 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.CheckShape(builder.ConstantLiteral(input_literal_), 1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); 1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto gamma = builder.ConstantR1<float>({1.0, 1.0}); 1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto beta = builder.ConstantR1<float>({0.0, 0.0}); 1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Computation add = CreateScalarAddComputation(F32, &builder); 1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Reduce all dimensions except dimension 1. 1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); 1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum = builder.CheckShape( 1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, 1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*dimensions_to_reduce=*/{0, 2, 3}), 1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TwoElementVectorF32); 1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); 1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); 1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) / 1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::ElementsIn(*sum_shape)); 
1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto set_means = builder.Div(sum, count); 1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const float kEpsilon = 1e-9f; 1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto epsilon = builder.ConstantR0<float>(kEpsilon); 1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon}); 1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto activation_deviations = builder.Sub(input_activations, set_means, 1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}); 1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto dev_squares = builder.SquareF32(activation_deviations); 1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum_of_squares = builder.CheckShape( 1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add, 1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*dimensions_to_reduce=*/{0, 2, 3}), 1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TwoElementVectorF32); 1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto variance = builder.Div(sum_of_squares, count); 1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto standard_deviation = builder.SqrtF32(variance); 1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto standard_deviation_above_epsilon = builder.CheckShape( 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); 1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto gt_eps = builder.Select(standard_deviation_above_epsilon, 1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins standard_deviation, epsilon2); 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto normalization_factors = builder.ReciprocalF32(gt_eps); 
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto normalized_input_activations = 1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Mul(activation_deviations, normalization_factors, 1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}); 1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /* auto output_activations = */ builder.Add( 1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Mul(normalized_input_activations, gamma, 1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}), 1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins beta, /*broadcast_dimensions=*/{1}); 1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> expected(kSamples, kZ, kY, kX); 1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> pz({ 1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)}, 1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {0.f, -.1f / std::sqrt(.02f)}, 1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)}, 1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected.FillWithPZ(pz); 1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2021464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerstruct BatchNormTestParam { 2031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<int64> bounds; 2041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower int64 feature_index; 2051464b9930de871fd11870941963253670f737c23A. 
Unique TensorFlower float random_value_mean; 2061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower float random_value_var; 207913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar 208913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar friend ::std::ostream& operator<<(::std::ostream& os, 209913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar const BatchNormTestParam& p) { 210913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; 211913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "feature_index=" << p.feature_index << ", "; 212913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "random_value_mean=" << p.random_value_mean << ", "; 213913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "random_value_var=" << p.random_value_var; 214913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar return os; 215913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar } 2161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}; 2171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower// Tests to test the fused operation of BatchNorm. 2191464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerclass BatchNormTest : public ClientLibraryTestBase, 2201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower public ::testing::WithParamInterface<BatchNormTestParam> { 2211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}; 2221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 223f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedTests) { 2241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower float epsilon = 0.001; 2251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 2261464b9930de871fd11870941963253670f737c23A. 
Unique TensorFlower const std::vector<int64>& bounds = GetParam().bounds; 2271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]); 2281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower input_array.FillRandom(GetParam().random_value_var, 2291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower GetParam().random_value_mean); 2301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int64 feature_index = GetParam().feature_index; 2321464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int64 num_elements_per_feature = 2331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Product(bounds) / bounds[feature_index]; 2341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int64 feature_bound = bounds[feature_index]; 2351464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> offset(feature_bound, 1); 2361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> scale(feature_bound, 2); 2371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto input_squared = 2391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); 2401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<int64> reduce_dims; 241030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) { 2421464b9930de871fd11870941963253670f737c23A. Unique TensorFlower if (i != feature_index) { 2431464b9930de871fd11870941963253670f737c23A. Unique TensorFlower reduce_dims.push_back(i); 2441464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2451464b9930de871fd11870941963253670f737c23A. 
Unique TensorFlower } 2461464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2471464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto sum = 2481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, 2491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower [](float a, float b) { return a + b; }); 2501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto sum_squared = 2521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, 2531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower [](float a, float b) { return a + b; }); 2541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> mean(feature_bound); 2561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 2581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower mean[i] = sum[i] / num_elements_per_feature; 2591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> mean_square(feature_bound); 2621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 2631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower mean_square[i] = mean[i] * mean[i]; 2641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2661464b9930de871fd11870941963253670f737c23A. 
Unique TensorFlower std::vector<float> square_mean(feature_bound); 2671464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 2681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower square_mean[i] = sum_squared[i] / num_elements_per_feature; 2691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> var(feature_bound); 2721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 2731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower var[i] = square_mean[i] - mean_square[i]; 2741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 276f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower Array4D<float> mean4D = 2771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); 278f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index); 279f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); 280f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto offset4D = 2811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index); 2821464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 283f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, 284f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower scale4D, offset4D, epsilon); 2851464b9930de871fd11870941963253670f737c23A. 
Unique TensorFlower 2861464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized); 2871464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset_literal = Literal::CreateR1<float>(offset); 2891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale_literal = Literal::CreateR1<float>(scale); 2901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto input_literal = Literal::CreateR4FromArray4D<float>(input_array); 2911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto input_activations = 2931464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.Parameter(0, input_literal->shape(), "input"); 2941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale_activations = 2951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.Parameter(1, scale_literal->shape(), "offset"); 2961464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset_activations = 2971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.Parameter(2, offset_literal->shape(), "scale"); 2981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected = *Literal::MakeTuple({expected_normalized.get(), 3001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(mean).get(), 3011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(var).get()}); 3021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::unique_ptr<GlobalData> input_data = 3041464b9930de871fd11870941963253670f737c23A. 
Unique TensorFlower client_->TransferToServer(*input_literal).ConsumeValueOrDie(); 3051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::unique_ptr<GlobalData> scale_data = 3061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); 3071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::unique_ptr<GlobalData> offset_data = 3081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); 3091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.BatchNormTraining(input_activations, scale_activations, 3111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower offset_activations, epsilon, feature_index); 3121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputeAndCompareTuple( 3141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower &builder, expected, 3151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {input_data.get(), scale_data.get(), offset_data.get()}, 3161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ErrorSpec(0.01, 1)); 3171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}

// Runs BatchNormInference on a randomly filled 4D input and compares the
// device result against a reference computed on the host with ReferenceUtil.
// The per-feature mean and variance fed to the op are derived from the same
// input (variance = E[x^2] - E[x]^2); scale is 2 and offset is 1 everywhere.
XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
  const float epsilon = 0.001f;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });

  // Reduce over every dimension except the feature dimension.
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  // Per-feature statistics: mean = E[x], var = E[x^2] - E[x]^2.
  std::vector<float> mean(feature_bound);
  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
    const float mean_square = mean[i] * mean[i];
    const float square_mean = sum_squared[i] / num_elements_per_feature;
    var[i] = square_mean - mean_square;
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto offset_literal = Literal::CreateR1<float>(offset);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto mean_literal = Literal::CreateR1<float>(mean);
  auto var_literal = Literal::CreateR1<float>(var);
  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);

  // Parameter names now match the data bound to each slot; previously the
  // scale parameter was labelled "offset" and the offset parameter "scale".
  auto input_activations =
      builder.Parameter(0, input_literal->shape(), "input");
  auto scale_activations =
      builder.Parameter(1, scale_literal->shape(), "scale");
  auto offset_activations =
      builder.Parameter(2, offset_literal->shape(), "offset");
  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
  auto variance_activations =
      builder.Parameter(4, var_literal->shape(), "variance");

  Array4D<float> expected = normalized;

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> variance_data =
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();

  builder.BatchNormInference(input_activations, scale_activations,
                             offset_activations, mean_activations,
                             variance_activations, epsilon, feature_index);

  ComputeAndCompareR4<float>(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
       variance_data.get()},
      ErrorSpec(0.01, 1));
}

// Runs BatchNormGrad on random activations and random output gradients and
// compares the resulting (grad_activation, grad_scale, grad_offset) tuple
// against the same quantities computed on the host with ReferenceUtil.
XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
  const float epsilon = 0.001f;
  ComputationBuilder builder(client_, TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  grad_output_array.FillRandom(GetParam().random_value_var,
                               GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });

  // Reduce over every dimension except the feature dimension.
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  // Per-feature statistics: mean = E[x], var = E[x^2] - E[x]^2.
  std::vector<float> mean(feature_bound);
  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
    const float mean_square = mean[i] * mean[i];
    const float square_mean = sum_squared[i] / num_elements_per_feature;
    var[i] = square_mean - mean_square;
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);

  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
      var4D, [epsilon](float a) { return a + epsilon; });

  // No capture needed here: epsilon has already been added above.
  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });

  auto activation_shifted = *ReferenceUtil::MapArray4D(
      input_array, mean4D, [](float a, float b) { return a - b; });

  auto activation_shifted_times_grad_output =
      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
                                 [](float a, float b) { return a * b; });

  // grad_scale = sum(output_grad * (activation - mean) * rsqrt(var + epsilon))
  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
      [](float a, float b) { return a * b; });

  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
      [](float a, float b) { return a + b; });

  // grad_offset = sum(output_grad)
  auto grad_offset =
      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });

  // I1 = N * output_grad, where N is the element count per feature.
  auto I1 = *ReferenceUtil::MapArray4D(
      grad_output_array, [&](float a) { return num_elements_per_feature * a; });

  // I2 = sum(output_grad), broadcast back to the input shape.
  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);

  // I3 = sum(output_grad * (activation - mean(activation)))
  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
                                  /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; }),
      bounds, feature_index);

  // I4 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation)))
  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
                                       [](float a, float b) { return a * b; });

  // I5 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon))
  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
                                       [](float a, float b) { return a / b; });

  // grad_activation = scale * rsqrt(var + epsilon) * (I1 - I2 - I5) / N
  auto grad_activation = *ReferenceUtil::MapArray4D(
      I1, I2, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, I5, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, scale4D, [](float a, float b) { return a * b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, rsqrt_var_add_epsilon,
      [=](float a, float b) { return a * b / num_elements_per_feature; });

  auto expected_grad_activation =
      Literal::CreateR4FromArray4D<float>(grad_activation);

  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
  auto scale_literal = Literal::CreateR1<float>(scale);
  auto mean_literal = Literal::CreateR1<float>(mean);
  auto var_literal = Literal::CreateR1<float>(var);
  auto grad_output_literal =
      Literal::CreateR4FromArray4D<float>(grad_output_array);

  auto input_parameter = builder.Parameter(0, input_literal->shape(), "input");
  auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale");
  auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean");
  auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance");
  auto grad_output_parameter =
      builder.Parameter(4, grad_output_literal->shape(), "grad_output");

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> var_data =
      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> grad_output_data =
      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();

  // The returned handle is not needed; the op is simply the last instruction
  // added to the builder, as in the training and inference tests.
  builder.BatchNormGrad(input_parameter, scale_parameter, mean_parameter,
                        var_parameter, grad_output_parameter, epsilon,
                        feature_index);

  auto expected =
      *Literal::MakeTuple({expected_grad_activation.get(),
                           Literal::CreateR1<float>(grad_scale).get(),
                           Literal::CreateR1<float>(grad_offset).get()});

  ComputeAndCompareTuple(&builder, expected,
                         {input_data.get(), scale_data.get(), mean_data.get(),
                          var_data.get(), grad_output_data.get()},
                         ErrorSpec(0.01, 1));
}

// Parameter sets shared by the randomized tests above. Each entry lists the
// 4D bounds, the feature dimension index, and two values consumed through
// GetParam().random_value_mean / GetParam().random_value_var.
// NOTE(review): the order of the last two fields is defined by
// BatchNormTestParam, which is declared outside this section -- confirm there.
INSTANTIATE_TEST_CASE_P(
    BatchNormTest_Instantiation, BatchNormTest,
    ::testing::Values(BatchNormTestParam{{2, 2, 2, 2}, 0, 100.2f, 200.0f},
                      BatchNormTestParam{{2, 2, 2, 2}, 3, 300.f, 400.0f},

                      BatchNormTestParam{{1, 10, 1, 1}, 0, 10.1f, 20.1f},
                      BatchNormTestParam{{10, 10, 10, 10}, 1, 3.14f, 314.15f},
                      BatchNormTestParam{{10, 10, 10, 10}, 2, 666.6f, 777.7f},
                      BatchNormTestParam{{10, 10, 10, 10}, 1, -666.6f, 777.7f},
                      BatchNormTestParam{{10, 10, 10, 10}, 2, 0.f, 777.7f},
                      BatchNormTestParam{{1, 1, 10, 130}, 2, 0.f, 777.7f},
                      BatchNormTestParam{{1, 1, 130, 11}, 2, 0.f, 777.7f},
                      BatchNormTestParam{{1, 1, 10, 1}, 3, 888.8f, 9.9f},

                      BatchNormTestParam{{24, 129, 1, 2}, 2, 10000, 10000},
                      BatchNormTestParam{{24, 129, 1, 2}, 3, 10000, 10000},

                      // Feature on low dimension to trigger relayout, test
                      // internal logical to physical dimension calculation
                      // is correct after relayout.
                      BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100}));

// BatchNormTraining on a constant 2x2x1x2 operand with the feature on
// dimension 3. Feature 0 holds {1,3,5,7} and feature 1 holds {2,4,6,8}, so the
// expected per-feature mean is {4,5} and the variance is {5,5}.
XLA_TEST_F(BatchNormTest, BasicTraining) {
  const int kFeatureIndex = 3;
  ComputationBuilder builder(client_, TestName());

  auto operand = builder.ConstantR4FromArray4D<float>(
      {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});

  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});

  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});

  builder.BatchNormTraining(operand, scale, offset,
                            /*epsilon=*/0.001, kFeatureIndex);

  auto expected = *Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                 {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
           .get(),
       Literal::CreateR1<float>({4, 5}).get(),
       Literal::CreateR1<float>({5, 5}).get()});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

// Same data as BasicTraining, laid out as 2x2x2x1 with the feature on
// dimension 2.
XLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) {
  const int kFeatureIndex = 2;
  ComputationBuilder builder(client_, TestName());

  auto operand = builder.ConstantR4FromArray4D<float>(
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  auto scale = builder.ConstantR1<float>({2.0f, 3.0f});

  auto offset = builder.ConstantR1<float>({1.0f, 2.0f});

  builder.BatchNormTraining(operand, scale, offset,
                            /*epsilon=*/0.001, kFeatureIndex);

  auto expected = *Literal::MakeTuple(
      {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                                 {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
           .get(),
       Literal::CreateR1<float>({4, 5}).get(),
       Literal::CreateR1<float>({5, 5}).get()});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

// BatchNormTraining on a 260x2x2 operand filled with 1.0, with 260 features on
// dimension 0. Every element equals its feature mean, so the expected output
// is the offset (1.0), the mean is 1.0 and the variance is 0.0.
XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
  // Use 0 dimension as feature, tests layout analyzer.
  const int kFeatureIndex = 0;
  ComputationBuilder builder(client_, TestName());

  ComputationDataHandle h0;
  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  ComputationDataHandle h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  ComputationDataHandle h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  builder.BatchNormTraining(h0, h1, h2,
                            /*epsilon=*/1, kFeatureIndex);

  auto expected = *Literal::MakeTuple(
      {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
           .get(),
       Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
       Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});

  ComputeAndCompareTuple(&builder, expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

696f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, LargeEpsilonTest) { 6971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower // Test the correctness of choosing a large epsilon value. 6981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int kFeatureIndex = 2; 6991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 7001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 7011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationDataHandle h0; 7021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, 7031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*parameter_number=*/0, "operand", 7041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower &builder, &h0); 7051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationDataHandle h1; 7061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale = 7071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower CreateR1Parameter<float>(std::vector<float>(1, 1.0f), 7081464b9930de871fd11870941963253670f737c23A. 
Unique TensorFlower /*parameter_number=*/1, "scale", &builder, &h1); 7091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationDataHandle h2; 7101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset = 7111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower CreateR1Parameter<float>(std::vector<float>(1, 0.0f), 7121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*parameter_number=*/2, "offset", &builder, &h2); 7131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 7141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower // var = 125, mean = 15, epsilon = -100 7151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto tuple = builder.BatchNormTraining(h0, h1, h2, 7161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*epsilon=*/-100, kFeatureIndex); 7171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 7181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected = *Literal::MakeTuple( 7191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) 7201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower .get(), 7211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(), 7221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()}); 7231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 7241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputeAndCompareTuple(&builder, expected, 7251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {operand.get(), scale.get(), offset.get()}, 7261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ErrorSpec(0.1)); 7271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 7281464b9930de871fd11870941963253670f737c23A. 
Unique TensorFlower 729f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BatchNormGradBasic) { 730ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const int kFeatureIndex = 2; 731ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 732ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 733ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto operand = 734ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f)); 735ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 736ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto scale = builder.ConstantR1<float>({1.0f, 1.0f}); 737ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 738ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto mean = builder.ConstantR1<float>({0.0f, 0.0f}); 739ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 740ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto var = builder.ConstantR1<float>({1.0f, 1.0f}); 741ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 742ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_output = builder.ConstantR4FromArray4D<float>( 743ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); 744ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 745ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower builder.BatchNormGrad(operand, scale, mean, var, grad_output, 746ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower /*epsilon=*/0.0, kFeatureIndex); 747ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 748ac209ebc8fc8780bb3121a33740e10a34352996fA. 
Unique TensorFlower auto expected = *Literal::MakeTuple( 749030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, 750030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) 751ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower .get(), 752ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR1<float>({0, 0}).get(), 753ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR1<float>({16, 20}).get()}); 754ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 755ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); 756ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower} 757ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 7581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace 7591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 760