batch_normalization_test.cc revision 3d05024db5fc0c39ee3f5626639bd611f44ac03c
11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <cmath> 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory> 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector> 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h" 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array4d.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/lib/arithmetic.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/local_client.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h" 271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/reference_util.h" 281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_computation.h" 291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h" 301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_module.h" 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h" 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/statusor.h" 331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test.h" 341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test_helpers.h" 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h" 381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_macros.h" 391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_utils.h" 401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/util.h" 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h" 423d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower#include "tensorflow/core/lib/math/math_util.h" 43913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar#include "tensorflow/core/lib/strings/str_util.h" 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/test.h" 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h" 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass BatchNormalizationTest : public ClientLibraryTestBase { 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected: 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) { 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> pz({ 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // z0 z1 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {-1.0f, 4.1f}, // p0 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {2.0f, 4.1f}, // p1 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {5.0f, 4.4f}, // p2 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins input_array_.FillWithPZ(pz); 6146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower input_literal_ = *Literal::CreateR4FromArray4D(input_array_); 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kSamples, input_array_.planes()); 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kZ, input_array_.depth()); 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kY, input_array_.height()); 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kY, input_array_.width()); 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kSamples = 3; 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kX = 1; 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kY = 1; 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kZ = 2; 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> input_array_; 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Literal input_literal_; 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const ErrorSpec error_spec_{0.001, 0.001}; 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SubtractInZ) { 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "subtract_in_z_one_sample"); 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto x = builder.ConstantLiteral(input_literal_); 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto y = builder.ConstantR1<float>({3.14, 4.25}); 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Sub(x, y, /*broadcast_dimensions=*/{1}); 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> expected(kSamples, kZ, kY, kX); 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> pz({ 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {-1.0f - 3.14f, 4.1f - 4.25f}, // p0 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {2.0f - 3.14f, 4.1f - 4.25f}, // p1 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {5.0f - 3.14f, 4.4f - 4.25f}, // p2 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected.FillWithPZ(pz); 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SquareTesseractElementwise) { 951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "square_tesseract_elementwise"); 961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto x = builder.ConstantLiteral(input_literal_); 971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.SquareF32(x); 981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 993d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower using tensorflow::MathUtil; 1003d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower 1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> expected(kSamples, kZ, kY, kX); 1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> expected_pz({ 1033d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)}, 1043d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)}, 1053d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)}, 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected.FillWithPZ(expected_pz); 1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SumToZ) { 1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "sum_to_z"); 1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_activations = builder.ConstantLiteral(input_literal_); 1141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Computation add = CreateScalarAddComputation(F32, &builder); 1151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Reduce all but the Z dimension. 1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, 1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {0, 2, 3}); 1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> expected = {6, 12.6}; 1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); 1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SquareAndReduce) { 1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "square_and_reduce"); 1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_activations = builder.ConstantLiteral(input_literal_); 1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto set_means = builder.ConstantR1<float>({2.f, 4.2f}); 1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto activation_deviations = builder.Sub(input_activations, set_means, 1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}); 1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Computation add = CreateScalarAddComputation(F32, &builder); 1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto dev_squares = builder.SquareF32(activation_deviations); 1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum_of_squares = builder.Reduce( 1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3}); 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> expected = {18, 0.06}; 1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); 1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, VarianceToStddev) { 1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "variance_to_stddev"); 1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto variance = builder.ConstantR1<float>({6.f, .02f}); 1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sqrt = builder.SqrtF32(variance); 1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> expected = {2.44948974f, 0.14142136f}; 1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Compare against a forward batch normalization example in the NN spec 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// reference. 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTEST_F(BatchNormalizationTest, SpecComparisonForward) { 1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "batch_normalize_per_spec"); 1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_activations = 1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.CheckShape(builder.ConstantLiteral(input_literal_), 1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); 1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto gamma = builder.ConstantR1<float>({1.0, 1.0}); 1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto beta = builder.ConstantR1<float>({0.0, 0.0}); 1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Computation add = CreateScalarAddComputation(F32, &builder); 1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Reduce all dimensions except dimension 1. 1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); 1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum = builder.CheckShape( 1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, 1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*dimensions_to_reduce=*/{0, 2, 3}), 1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TwoElementVectorF32); 1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); 1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); 1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) / 1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::ElementsIn(*sum_shape)); 1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto set_means = builder.Div(sum, count); 1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const float kEpsilon = 1e-9f; 1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto epsilon = builder.ConstantR0<float>(kEpsilon); 1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon}); 1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto activation_deviations = builder.Sub(input_activations, set_means, 1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}); 1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto dev_squares = builder.SquareF32(activation_deviations); 1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum_of_squares = builder.CheckShape( 1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add, 1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*dimensions_to_reduce=*/{0, 2, 3}), 1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TwoElementVectorF32); 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto variance = builder.Div(sum_of_squares, count); 1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto standard_deviation = builder.SqrtF32(variance); 1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto standard_deviation_above_epsilon = builder.CheckShape( 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); 1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto gt_eps = builder.Select(standard_deviation_above_epsilon, 1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins standard_deviation, epsilon2); 1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto normalization_factors = builder.ReciprocalF32(gt_eps); 1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto normalized_input_activations = 1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Mul(activation_deviations, normalization_factors, 1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}); 1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /* auto output_activations = */ builder.Add( 1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Mul(normalized_input_activations, gamma, 1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}), 1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins beta, /*broadcast_dimensions=*/{1}); 1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> expected(kSamples, kZ, kY, kX); 1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> pz({ 1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)}, 1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {0.f, -.1f / std::sqrt(.02f)}, 1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)}, 1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected.FillWithPZ(pz); 2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2051464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerstruct BatchNormTestParam { 2061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<int64> bounds; 2071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower int64 feature_index; 2081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower float random_value_mean; 2091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower float random_value_var; 210913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar 211913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar friend ::std::ostream& operator<<(::std::ostream& os, 212913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar const BatchNormTestParam& p) { 213913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; 214913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "feature_index=" << p.feature_index << ", "; 215913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "random_value_mean=" << p.random_value_mean << ", "; 216913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "random_value_var=" << p.random_value_var; 217913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar return os; 218913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar } 2191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}; 2201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower// Tests to test the fused operation of BatchNorm. 2221464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerclass BatchNormTest : public ClientLibraryTestBase, 2231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower public ::testing::WithParamInterface<BatchNormTestParam> { 2241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}; 2251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 226f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedTests) { 2271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower float epsilon = 0.001; 2281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 2291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const std::vector<int64>& bounds = GetParam().bounds; 2301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]); 2311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower input_array.FillRandom(GetParam().random_value_var, 2321464b9930de871fd11870941963253670f737c23A. Unique TensorFlower GetParam().random_value_mean); 2331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int64 feature_index = GetParam().feature_index; 2351464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int64 num_elements_per_feature = 2361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Product(bounds) / bounds[feature_index]; 2371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int64 feature_bound = bounds[feature_index]; 2381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> offset(feature_bound, 1); 2391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> scale(feature_bound, 2); 2401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2411464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto input_squared = 2421464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); 2431464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<int64> reduce_dims; 244030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) { 2451464b9930de871fd11870941963253670f737c23A. Unique TensorFlower if (i != feature_index) { 2461464b9930de871fd11870941963253670f737c23A. Unique TensorFlower reduce_dims.push_back(i); 2471464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto sum = 2511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, 2521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower [](float a, float b) { return a + b; }); 2531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto sum_squared = 2551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, 2561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower [](float a, float b) { return a + b; }); 2571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> mean(feature_bound); 2591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 2611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower mean[i] = sum[i] / num_elements_per_feature; 2621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> mean_square(feature_bound); 2651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 2661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower mean_square[i] = mean[i] * mean[i]; 2671464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> square_mean(feature_bound); 2701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 2711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower square_mean[i] = sum_squared[i] / num_elements_per_feature; 2721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> var(feature_bound); 2751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 2761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower var[i] = square_mean[i] - mean_square[i]; 2771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 2781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 279f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower Array4D<float> mean4D = 2801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); 281f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index); 282f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); 283f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto offset4D = 2841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index); 2851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 286f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, 287f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower scale4D, offset4D, epsilon); 2881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized); 2901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset_literal = Literal::CreateR1<float>(offset); 2921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale_literal = Literal::CreateR1<float>(scale); 2931464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto input_literal = Literal::CreateR4FromArray4D<float>(input_array); 2941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 2951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto input_activations = 2961464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.Parameter(0, input_literal->shape(), "input"); 2971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale_activations = 2981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.Parameter(1, scale_literal->shape(), "offset"); 2991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset_activations = 3001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.Parameter(2, offset_literal->shape(), "scale"); 3011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected = *Literal::MakeTuple({expected_normalized.get(), 3031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(mean).get(), 3041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(var).get()}); 3051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::unique_ptr<GlobalData> input_data = 3071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower client_->TransferToServer(*input_literal).ConsumeValueOrDie(); 3081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::unique_ptr<GlobalData> scale_data = 3091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); 3101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::unique_ptr<GlobalData> offset_data = 3111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); 3121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.BatchNormTraining(input_activations, scale_activations, 3141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower offset_activations, epsilon, feature_index); 3151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputeAndCompareTuple( 3171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower &builder, expected, 3181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {input_data.get(), scale_data.get(), offset_data.get()}, 3191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ErrorSpec(0.01, 1)); 3201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 3211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3227359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedInferencingTests) { 3237359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower float epsilon = 0.001; 3247359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 3257359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower const std::vector<int64>& bounds = GetParam().bounds; 3267359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]); 3277359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower input_array.FillRandom(GetParam().random_value_var, 3287359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower GetParam().random_value_mean); 3297359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3307359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower const int64 feature_index = GetParam().feature_index; 3317359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower const int64 num_elements_per_feature = 3327359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower Product(bounds) / bounds[feature_index]; 3337359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower const int64 feature_bound = bounds[feature_index]; 3347359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> offset(feature_bound, 1); 3357359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> scale(feature_bound, 2); 3367359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3377359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto input_squared = 3387359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); 3397359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<int64> reduce_dims; 3407359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) { 3417359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower if (i != feature_index) { 3427359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower reduce_dims.push_back(i); 3437359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 3447359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 3457359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3467359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto sum = 3477359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, 3487359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower [](float a, float b) { return a + b; }); 3497359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3507359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto sum_squared = 3517359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, 3527359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower [](float a, float b) { return a + b; }); 3537359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3547359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> mean(feature_bound); 3557359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3567359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 3577359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower mean[i] = sum[i] / num_elements_per_feature; 3587359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 3597359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3607359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> mean_square(feature_bound); 3617359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 3627359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower mean_square[i] = mean[i] * mean[i]; 3637359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 3647359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3657359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> square_mean(feature_bound); 3667359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 3677359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower square_mean[i] = sum_squared[i] / num_elements_per_feature; 3687359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 3697359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3707359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> var(feature_bound); 3717359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 3727359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower var[i] = square_mean[i] - mean_square[i]; 3737359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 3747359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3757359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower Array4D<float> mean4D = 3767359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); 3777359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index); 3787359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); 3797359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto offset4D = 3807359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index); 3817359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3827359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, 3837359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower scale4D, offset4D, epsilon); 3847359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3857359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto offset_literal = Literal::CreateR1<float>(offset); 3867359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto scale_literal = Literal::CreateR1<float>(scale); 3877359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto mean_literal = Literal::CreateR1<float>(mean); 3887359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto var_literal = Literal::CreateR1<float>(var); 3897359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto input_literal = Literal::CreateR4FromArray4D<float>(input_array); 3907359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 3917359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto input_activations = 3927359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.Parameter(0, input_literal->shape(), "input"); 3937359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto scale_activations = 3947359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.Parameter(1, scale_literal->shape(), "offset"); 3957359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto offset_activations = 3967359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.Parameter(2, offset_literal->shape(), "scale"); 3977359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean"); 3987359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto variance_activations = 3997359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.Parameter(4, var_literal->shape(), "variance"); 4007359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 4017359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower Array4D<float> expected = normalized; 4027359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 4037359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> input_data = 4047359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*input_literal).ConsumeValueOrDie(); 4057359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> scale_data = 4067359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); 4077359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> offset_data = 4087359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); 4097359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> mean_data = 4107359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); 4117359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> variance_data = 4127359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*var_literal).ConsumeValueOrDie(); 4137359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 4147359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.BatchNormInference(input_activations, scale_activations, 4157359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower offset_activations, mean_activations, 4167359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower variance_activations, epsilon, feature_index); 4177359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 4187359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ComputeAndCompareR4<float>( 4197359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower &builder, expected, 4207359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(), 4217359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower variance_data.get()}, 4227359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ErrorSpec(0.01, 1)); 4237359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower} 4247359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 425f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_P(BatchNormTest, RandomizedGradTests) { 426ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower float epsilon = 0.001; 427ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 428ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const std::vector<int64>& bounds = GetParam().bounds; 429ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]); 430ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower input_array.FillRandom(GetParam().random_value_var, 431ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower GetParam().random_value_mean); 432ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 433ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]); 434ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower grad_output_array.FillRandom(GetParam().random_value_var, 435ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower GetParam().random_value_mean); 436ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 437ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const int64 feature_index = GetParam().feature_index; 438ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const int64 num_elements_per_feature = 439ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Product(bounds) / bounds[feature_index]; 440ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const int64 feature_bound = bounds[feature_index]; 441ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> scale(feature_bound, 2); 442ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 443ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto input_squared = 444ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); 445ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<int64> reduce_dims; 446030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) { 447ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower if (i != feature_index) { 448ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower reduce_dims.push_back(i); 449ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 450ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 451ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 452ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto sum = 453ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, 454ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a + b; }); 455ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 456ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto sum_squared = 457ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, 458ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a + b; }); 459ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 460ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> mean(feature_bound); 461ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 462ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 463ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower mean[i] = sum[i] / num_elements_per_feature; 464ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 465ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 466ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> mean_square(feature_bound); 467ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 468ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower mean_square[i] = mean[i] * mean[i]; 469ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 470ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 471ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> square_mean(feature_bound); 472ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 473ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower square_mean[i] = sum_squared[i] / num_elements_per_feature; 474ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 475ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 476ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> var(feature_bound); 477ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 478ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower var[i] = square_mean[i] - mean_square[i]; 479ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 480ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 481f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower Array4D<float> mean4D = 482ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); 483f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index); 484f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); 485ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 486ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto var_add_epsilon = *ReferenceUtil::MapArray4D( 487030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower var4D, [epsilon](float a) { return a + epsilon; }); 488030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 489030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D( 490030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); }); 491ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 492ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_output_times_var = 493ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon, 494ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a * b; }); 495ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 496ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto activation_shifted = *ReferenceUtil::MapArray4D( 497f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower input_array, mean4D, [](float a, float b) { return a - b; }); 498ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 499030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto activation_shifted_times_grad_output = 500030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted, 501ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a * b; }); 502ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 503030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D( 504030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower activation_shifted_times_grad_output, rsqrt_var_add_epsilon, 505030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower [](float a, float b) { return a * b; }); 506030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 507ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_scale = ReferenceUtil::Reduce4DTo1D( 508ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower grad_scale_before_reduction, /*init=*/0.0f, reduce_dims, 509ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a + b; }); 510ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 511ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_offset = 512ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims, 513ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a + b; }); 514ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 515030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D( 516030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; }); 517030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 518030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I1 = *ReferenceUtil::MapArray4D( 519030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_output_array, [&](float a) { return num_elements_per_feature * a; }); 520030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 521030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index); 522030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 523030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // I3 = sum(output_grad * (activation - mean(activation))) 524030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I3 = *ReferenceUtil::Broadcast1DTo4D( 525030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output, 526030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower /*init=*/0.0f, reduce_dims, 527030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower [](float a, float b) { return a + b; }), 528030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower bounds, feature_index); 529030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 530030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // I4 = (activation - mean(activation)) * 531030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // sum(output_grad * (activation - mean(activation))) 532030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted, 533030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower [](float a, float b) { return a * b; }); 534030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 535030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // I5 = (activation - mean(activation)) * 536030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // sum(output_grad * (activation - mean(activation))) / (variance + 537030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // epsilon)) 538030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon, 539030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower [](float a, float b) { return a / b; }); 540030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 541030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto grad_activation = *ReferenceUtil::MapArray4D( 542030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower I1, I2, [](float a, float b) { return a - b; }); 543030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 544030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation = *ReferenceUtil::MapArray4D( 545030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation, I5, [](float a, float b) { return a - b; }); 546030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 547030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation = *ReferenceUtil::MapArray4D( 548030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation, scale4D, [](float a, float b) { return a * b; }); 549030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 550030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation = *ReferenceUtil::MapArray4D( 551030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation, rsqrt_var_add_epsilon, 552030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower [=](float a, float b) { return a * b / num_elements_per_feature; }); 553030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 554ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto expected_grad_activation = 555ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR4FromArray4D<float>(grad_activation); 556ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 557ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto input_literal = Literal::CreateR4FromArray4D<float>(input_array); 558ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto scale_literal = Literal::CreateR1<float>(scale); 559ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto mean_literal = Literal::CreateR1<float>(mean); 560ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto var_literal = Literal::CreateR1<float>(var); 561ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_output_literal = 562ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR4FromArray4D<float>(grad_output_array); 563ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 564ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto input_parameter = builder.Parameter(0, input_literal->shape(), "input"); 565ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale"); 566ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean"); 567ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance"); 568ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_output_parameter = 569ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower builder.Parameter(4, grad_output_literal->shape(), "grad_output"); 570ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 571ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> input_data = 572ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*input_literal).ConsumeValueOrDie(); 573ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> scale_data = 574ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); 575ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> mean_data = 576ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); 577ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> var_data = 578ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*var_literal).ConsumeValueOrDie(); 579ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> grad_output_data = 580ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); 581ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 582ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto t = builder.BatchNormGrad(input_parameter, scale_parameter, 583ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower mean_parameter, var_parameter, 584ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower grad_output_parameter, epsilon, feature_index); 585ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 586ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto expected = 587ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower *Literal::MakeTuple({expected_grad_activation.get(), 588ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR1<float>(grad_scale).get(), 589ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR1<float>(grad_offset).get()}); 590ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 591ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ComputeAndCompareTuple(&builder, expected, 592ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower {input_data.get(), scale_data.get(), mean_data.get(), 593ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower var_data.get(), grad_output_data.get()}, 594ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ErrorSpec(0.01, 1)); 595ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower} 596ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 5971464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerINSTANTIATE_TEST_CASE_P( 5981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTest_Instantiation, BatchNormTest, 5991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ::testing::Values(BatchNormTestParam{{2, 2, 2, 2}, 0, 100.2f, 200.0f}, 6001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{2, 2, 2, 2}, 3, 300.f, 400.0f}, 6011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{1, 10, 1, 1}, 0, 10.1f, 20.1f}, 6031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{10, 10, 10, 10}, 1, 3.14f, 314.15f}, 6041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{10, 10, 10, 10}, 2, 666.6f, 777.7f}, 6051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{10, 10, 10, 10}, 1, -666.6f, 777.7f}, 6061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{10, 10, 10, 10}, 2, 0.f, 777.7f}, 6074ab3342eb5fdae864e0a32a0e460a27527d79997A. Unique TensorFlower BatchNormTestParam{{1, 1, 10, 130}, 2, 0.f, 777.7f}, 608ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower BatchNormTestParam{{1, 1, 130, 11}, 2, 0.f, 777.7f}, 6091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{1, 1, 10, 1}, 3, 888.8f, 9.9f}, 6101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{24, 129, 1, 2}, 2, 10000, 10000}, 6121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{24, 129, 1, 2}, 3, 10000, 10000}, 6131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower // Feature on low dimension to trigger relayout, test 6151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower // internal logical to physical dimension calculation 6161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower // is correct after relayout. 6171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100})); 6181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 619f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BasicTraining) { 6201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int kFeatureIndex = 3; 6211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 6221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto operand = builder.ConstantR4FromArray4D<float>( 6241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); 6251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale = builder.ConstantR1<float>({2.0f, 3.0f}); 6271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset = builder.ConstantR1<float>({1.0f, 2.0f}); 6291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto tuple = builder.BatchNormTraining(operand, scale, offset, 6311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*epsilon=*/0.001, kFeatureIndex); 6321464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected = *Literal::MakeTuple( 6341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, 6351464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) 6361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower .get(), 6371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>({4, 5}).get(), 6381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>({5, 5}).get()}); 6391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); 6411464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 6421464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 643f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) { 6441464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int kFeatureIndex = 2; 6451464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 6461464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6471464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto operand = builder.ConstantR4FromArray4D<float>( 6481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); 6491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale = builder.ConstantR1<float>({2.0f, 3.0f}); 6511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset = builder.ConstantR1<float>({1.0f, 2.0f}); 6531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto tuple = builder.BatchNormTraining(operand, scale, offset, 6551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*epsilon=*/0.001, kFeatureIndex); 6561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected = *Literal::MakeTuple( 6581464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, 6591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) 6601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower .get(), 6611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>({4, 5}).get(), 6621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>({5, 5}).get()}); 6631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); 6651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 6661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6674c15e948e0e4a17329b2363a5f8f3d4b4178ef7bA. Unique TensorFlowerXLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) { 6681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower // Use 0 dimension as feature, tests layout analyzer. 6691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int kFeatureIndex = 0; 6701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 6711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationDataHandle h0; 6731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f), 6741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*parameter_number=*/0, "operand", 6751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower &builder, &h0); 6761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationDataHandle h1; 6771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale = 6781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower CreateR1Parameter<float>(std::vector<float>(260, 1.0f), 6791464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*parameter_number=*/1, "scale", &builder, &h1); 6801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationDataHandle h2; 6811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset = 6821464b9930de871fd11870941963253670f737c23A. Unique TensorFlower CreateR1Parameter<float>(std::vector<float>(260, 1.0f), 6831464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*parameter_number=*/2, "offset", &builder, &h2); 6841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto tuple = builder.BatchNormTraining(h0, h1, h2, 6861464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*epsilon=*/1, kFeatureIndex); 6871464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected = *Literal::MakeTuple( 6891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)) 6901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower .get(), 6911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(), 6921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()}); 6931464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 6941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputeAndCompareTuple(&builder, expected, 6951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {operand.get(), scale.get(), offset.get()}, 6961464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ErrorSpec(0.1)); 6971464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 6981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 699f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, LargeEpsilonTest) { 7001464b9930de871fd11870941963253670f737c23A. Unique TensorFlower // Test the correctness of choosing a large epsilon value. 7011464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int kFeatureIndex = 2; 7021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 7031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 7041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationDataHandle h0; 7051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, 7061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*parameter_number=*/0, "operand", 7071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower &builder, &h0); 7081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationDataHandle h1; 7091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale = 7101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower CreateR1Parameter<float>(std::vector<float>(1, 1.0f), 7111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*parameter_number=*/1, "scale", &builder, &h1); 7121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationDataHandle h2; 7131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset = 7141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower CreateR1Parameter<float>(std::vector<float>(1, 0.0f), 7151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*parameter_number=*/2, "offset", &builder, &h2); 7161464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 7171464b9930de871fd11870941963253670f737c23A. Unique TensorFlower // var = 125, mean = 15, epsilon = -100 7181464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto tuple = builder.BatchNormTraining(h0, h1, h2, 7191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower /*epsilon=*/-100, kFeatureIndex); 7201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 7211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected = *Literal::MakeTuple( 7221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) 7231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower .get(), 7241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(), 7251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()}); 7261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 7271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputeAndCompareTuple(&builder, expected, 7281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {operand.get(), scale.get(), offset.get()}, 7291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ErrorSpec(0.1)); 7301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 7311464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 732f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlowerXLA_TEST_F(BatchNormTest, BatchNormGradBasic) { 733ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const int kFeatureIndex = 2; 734ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 735ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 736ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto operand = 737ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f)); 738ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 739ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto scale = builder.ConstantR1<float>({1.0f, 1.0f}); 740ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 741ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto mean = builder.ConstantR1<float>({0.0f, 0.0f}); 742ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 743ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto var = builder.ConstantR1<float>({1.0f, 1.0f}); 744ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 745ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_output = builder.ConstantR4FromArray4D<float>( 746ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); 747ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 748ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower builder.BatchNormGrad(operand, scale, mean, var, grad_output, 749ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower /*epsilon=*/0.0, kFeatureIndex); 750ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 751ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto expected = *Literal::MakeTuple( 752030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, 753030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) 754ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower .get(), 755ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR1<float>({0, 0}).get(), 756ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR1<float>({16, 20}).get()}); 757ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 758ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); 759ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower} 760ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 7611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace 7621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 763