11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <cmath> 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <memory> 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector> 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array2d.h" 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/array4d.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/lib/arithmetic.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/local_client.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h" 271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/reference_util.h" 281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_computation.h" 291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h" 301464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_module.h" 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h" 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/statusor.h" 331464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test.h" 341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/test_helpers.h" 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/tests/literal_test_util.h" 381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_macros.h" 391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/test_utils.h" 401464b9930de871fd11870941963253670f737c23A. Unique TensorFlower#include "tensorflow/compiler/xla/util.h" 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h" 423d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower#include "tensorflow/core/lib/math/math_util.h" 43913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar#include "tensorflow/core/lib/strings/str_util.h" 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/test.h" 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/types.h" 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 516a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebarclass BatchNormalizationTest 526a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar : public ClientLibraryTestBase, 536a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar public ::testing::WithParamInterface<bool /*use_cudnn_batchnorm*/> { 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected: 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) { 566a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(GetParam()); 576a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> pz({ 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // z0 z1 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {-1.0f, 4.1f}, // p0 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {2.0f, 4.1f}, // p1 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {5.0f, 4.4f}, // p2 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins input_array_.FillWithPZ(pz); 657d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_)); 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kSamples, input_array_.planes()); 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kZ, input_array_.depth()); 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kY, input_array_.height()); 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK_EQ(kY, input_array_.width()); 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kSamples = 3; 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kX = 1; 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kY = 1; 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static constexpr int64 kZ = 2; 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> input_array_; 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Literal input_literal_; 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const ErrorSpec error_spec_{0.001, 0.001}; 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 826a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar// If testing the GPU backend, run the tests twice, with and without cudnn 836a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar// batchnorm. Otherwise, just run the tests once -- the value of this flag 846a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar// doesn't matter. 856a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#ifdef XLA_TEST_BACKEND_GPU 866a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarINSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, 876a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ::testing::Bool()); 886a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#else 896a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarINSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, 906a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ::testing::Values(false)); 916a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#endif 926a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 936a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SubtractInZ) { 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "subtract_in_z_one_sample"); 951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto x = builder.ConstantLiteral(input_literal_); 961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto y = builder.ConstantR1<float>({3.14, 4.25}); 971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Sub(x, y, /*broadcast_dimensions=*/{1}); 981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> expected(kSamples, kZ, kY, kX); 1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> pz({ 1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {-1.0f - 3.14f, 4.1f - 4.25f}, // p0 1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {2.0f - 3.14f, 4.1f - 4.25f}, // p1 1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {5.0f - 3.14f, 4.4f - 4.25f}, // p2 1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected.FillWithPZ(pz); 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1096a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { 1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "square_tesseract_elementwise"); 1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto x = builder.ConstantLiteral(input_literal_); 1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.SquareF32(x); 1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1143d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower using tensorflow::MathUtil; 1153d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower 1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> expected(kSamples, kZ, kY, kX); 1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> expected_pz({ 1183d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)}, 1193d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)}, 1203d05024db5fc0c39ee3f5626639bd611f44ac03cA. Unique TensorFlower {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)}, 1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected.FillWithPZ(expected_pz); 1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1266a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SumToZ) { 1271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "sum_to_z"); 1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_activations = builder.ConstantLiteral(input_literal_); 1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Computation add = CreateScalarAddComputation(F32, &builder); 1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Reduce all but the Z dimension. 1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, 1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {0, 2, 3}); 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> expected = {6, 12.6}; 1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); 1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1386a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { 1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "square_and_reduce"); 1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_activations = builder.ConstantLiteral(input_literal_); 1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto set_means = builder.ConstantR1<float>({2.f, 4.2f}); 1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto activation_deviations = builder.Sub(input_activations, set_means, 1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}); 1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Computation add = CreateScalarAddComputation(F32, &builder); 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto dev_squares = builder.SquareF32(activation_deviations); 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum_of_squares = builder.Reduce( 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3}); 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> expected = {18, 0.06}; 1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); 1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1536a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { 1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "variance_to_stddev"); 1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto variance = builder.ConstantR1<float>({6.f, .02f}); 1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sqrt = builder.SqrtF32(variance); 1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<float> expected = {2.44948974f, 0.14142136f}; 1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); 1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Compare against a forward batch normalization example in the NN spec 1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// reference. 1646a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { 1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputationBuilder builder(client_, "batch_normalize_per_spec"); 1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_activations = 1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.CheckShape(builder.ConstantLiteral(input_literal_), 1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); 1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto gamma = builder.ConstantR1<float>({1.0, 1.0}); 1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto beta = builder.ConstantR1<float>({0.0, 0.0}); 1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Computation add = CreateScalarAddComputation(F32, &builder); 1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Reduce all dimensions except dimension 1. 1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); 1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum = builder.CheckShape( 1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, 1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*dimensions_to_reduce=*/{0, 2, 3}), 1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TwoElementVectorF32); 1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); 1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) / 1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ShapeUtil::ElementsIn(*sum_shape)); 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto set_means = builder.Div(sum, count); 1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const float kEpsilon = 1e-9f; 1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto epsilon = builder.ConstantR0<float>(kEpsilon); 1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon}); 1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto activation_deviations = builder.Sub(input_activations, set_means, 1881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}); 1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto dev_squares = builder.SquareF32(activation_deviations); 1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto sum_of_squares = builder.CheckShape( 1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add, 1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*dimensions_to_reduce=*/{0, 2, 3}), 1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TwoElementVectorF32); 1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto variance = builder.Div(sum_of_squares, count); 1951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto standard_deviation = builder.SqrtF32(variance); 1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto standard_deviation_above_epsilon = builder.CheckShape( 1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); 1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto gt_eps = builder.Select(standard_deviation_above_epsilon, 1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins standard_deviation, epsilon2); 2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto normalization_factors = builder.ReciprocalF32(gt_eps); 2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto normalized_input_activations = 2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Mul(activation_deviations, normalization_factors, 2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}); 2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /* auto output_activations = */ builder.Add( 2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins builder.Mul(normalized_input_activations, gamma, 2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins /*broadcast_dimensions=*/{1}), 2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins beta, /*broadcast_dimensions=*/{1}); 2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array4D<float> expected(kSamples, kZ, kY, kX); 2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Array2D<float> pz({ 2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)}, 2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {0.f, -.1f / std::sqrt(.02f)}, 2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)}, 2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins expected.FillWithPZ(pz); 2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); 2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2206a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, BasicTraining) { 2216a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar const int kFeatureIndex = 3; 2226a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationBuilder builder(client_, TestName()); 2236a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2246a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto operand = builder.ConstantR4FromArray4D<float>( 2256a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); 2266a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2276a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto scale = builder.ConstantR1<float>({2.0f, 3.0f}); 2286a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2296a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto offset = builder.ConstantR1<float>({1.0f, 2.0f}); 2306a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2316a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto tuple = builder.BatchNormTraining(operand, scale, offset, 2326a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*epsilon=*/0.001, kFeatureIndex); 2336a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2347d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan auto expected = Literal::MakeTuple( 2356a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, 2366a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) 2376a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar .get(), 2386a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>({4, 5}).get(), 2396a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>({5, 5}).get()}); 2406a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2417d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); 2426a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar} 2436a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2446a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { 2456a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar const int kFeatureIndex = 2; 2466a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationBuilder builder(client_, TestName()); 2476a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2486a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto operand = builder.ConstantR4FromArray4D<float>( 2496a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); 2506a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2516a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto scale = builder.ConstantR1<float>({2.0f, 3.0f}); 2526a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2536a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto offset = builder.ConstantR1<float>({1.0f, 2.0f}); 2546a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2556a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto tuple = builder.BatchNormTraining(operand, scale, offset, 2566a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*epsilon=*/0.001, kFeatureIndex); 2576a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2587d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan auto expected = Literal::MakeTuple( 2596a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, 2606a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) 2616a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar .get(), 2626a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>({4, 5}).get(), 2636a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>({5, 5}).get()}); 2646a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2657d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); 2666a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar} 2676a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2686a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { 2696a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar // Use 0 dimension as feature, tests layout analyzer. 2706a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar const int kFeatureIndex = 0; 2716a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationBuilder builder(client_, TestName()); 2726a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2736a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationDataHandle h0; 2746a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f), 2756a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*parameter_number=*/0, "operand", 2766a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar &builder, &h0); 2776a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationDataHandle h1; 2786a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto scale = 2796a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar CreateR1Parameter<float>(std::vector<float>(260, 1.0f), 2806a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*parameter_number=*/1, "scale", &builder, &h1); 2816a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationDataHandle h2; 2826a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto offset = 2836a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar CreateR1Parameter<float>(std::vector<float>(260, 1.0f), 2846a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*parameter_number=*/2, "offset", &builder, &h2); 2856a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2866a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto tuple = builder.BatchNormTraining(h0, h1, h2, 2876a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*epsilon=*/1, kFeatureIndex); 2886a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2897d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan auto expected = Literal::MakeTuple( 2906a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)) 2916a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar .get(), 2926a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(), 2936a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()}); 2946a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 2957d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan ComputeAndCompareTuple(&builder, *expected, 2966a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {operand.get(), scale.get(), offset.get()}, 2976a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ErrorSpec(0.1)); 2986a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar} 2996a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3006a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { 3016a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar // Test the correctness of choosing a large epsilon value. 3026a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar const int kFeatureIndex = 2; 3036a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationBuilder builder(client_, TestName()); 3046a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3056a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationDataHandle h0; 3066a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, 3076a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*parameter_number=*/0, "operand", 3086a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar &builder, &h0); 3096a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationDataHandle h1; 3106a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto scale = 3116a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar CreateR1Parameter<float>(std::vector<float>(1, 1.0f), 3126a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*parameter_number=*/1, "scale", &builder, &h1); 3136a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationDataHandle h2; 3146a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto offset = 3156a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar CreateR1Parameter<float>(std::vector<float>(1, 0.0f), 3166a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*parameter_number=*/2, "offset", &builder, &h2); 3176a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3186a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar // var = 125, mean = 15, epsilon = -100 3196a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto tuple = builder.BatchNormTraining(h0, h1, h2, 3206a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*epsilon=*/-100, kFeatureIndex); 3216a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3227d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan auto expected = Literal::MakeTuple( 3236a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) 3246a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar .get(), 3256a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(), 3266a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()}); 3276a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3287d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan ComputeAndCompareTuple(&builder, *expected, 3296a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {operand.get(), scale.get(), offset.get()}, 3306a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ErrorSpec(0.1)); 3316a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar} 3326a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3336a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { 3346a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar const int kFeatureIndex = 2; 3356a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ComputationBuilder builder(client_, TestName()); 3366a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3376a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto operand = 3386a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f)); 3396a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3406a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto scale = builder.ConstantR1<float>({1.0f, 1.0f}); 3416a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3426a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto mean = builder.ConstantR1<float>({0.0f, 0.0f}); 3436a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3446a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto var = builder.ConstantR1<float>({1.0f, 1.0f}); 3456a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3466a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto grad_output = builder.ConstantR4FromArray4D<float>( 3476a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); 3486a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3496a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar builder.BatchNormGrad(operand, scale, mean, var, grad_output, 3506a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar /*epsilon=*/0.0, kFeatureIndex); 3516a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3527d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan auto expected = Literal::MakeTuple( 3536a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, 3546a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) 3556a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar .get(), 3566a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>({0, 0}).get(), 3576a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar Literal::CreateR1<float>({16, 20}).get()}); 3586a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3597d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); 3606a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar} 3616a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3621464b9930de871fd11870941963253670f737c23A. Unique TensorFlowerstruct BatchNormTestParam { 3631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<int64> bounds; 3641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower int64 feature_index; 3651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower float random_value_mean; 3661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower float random_value_var; 3676a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar bool use_cudnn_batchnorm; 368913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar 369913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar friend ::std::ostream& operator<<(::std::ostream& os, 370913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar const BatchNormTestParam& p) { 371913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; 372913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "feature_index=" << p.feature_index << ", "; 373913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "random_value_mean=" << p.random_value_mean << ", "; 374913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar os << "random_value_var=" << p.random_value_var; 3756a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3766a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar // Don't print use_cudnn_batchnorm when it's false, because most backends 3776a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar // never set it to true. 3786a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar if (p.use_cudnn_batchnorm) { 3796a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar os << ", use_cudnn_batchnorm=true"; 3806a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar } 381913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar return os; 382913175c2bd38f6e97de399b29cfe1195bffbaa25Justin Lebar } 3831464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}; 3841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower// Tests to test the fused operation of BatchNorm. 3866a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebarclass BatchNormTestManySizes 3876a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar : public ClientLibraryTestBase, 3886a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar public ::testing::WithParamInterface<BatchNormTestParam> { 3896a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar public: 3906a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar BatchNormTestManySizes() { 3916a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm( 3926a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar GetParam().use_cudnn_batchnorm); 3936a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar } 3941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower}; 3951464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 3966a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebarstd::vector<BatchNormTestParam> BuildBatchNormTestParams() { 3976a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar std::vector<BatchNormTestParam> params; 3986a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 3996a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar auto add_testcase = [&](std::vector<int64> bounds, int64 feature_index, 4006a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar float random_value_mean, float random_value_var) { 4016a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar BatchNormTestParam p{bounds, feature_index, random_value_mean, 4026a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar random_value_var, /*use_cudnn_batchnorm=*/false}; 4036a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar params.push_back(p); 4046a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 4056a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar // If testing the GPU backend, also run with cudnn batchnorm enabled. 4066a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#ifdef XLA_TEST_BACKEND_GPU 4076a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar p.use_cudnn_batchnorm = true; 4086a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar params.push_back(p); 4096a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar#endif 4106a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar }; 4116a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 4126a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f); 4136a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f); 4146a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 4156a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f); 4166a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f); 4176a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f); 4186a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f); 4196a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f); 4206a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f); 4216a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f); 4226a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f); 4236a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 4246a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({24, 129, 1, 2}, 2, 10000, 10000); 4256a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({24, 129, 1, 2}, 3, 10000, 10000); 4266a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 4276a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar // Feature on low dimension to trigger relayout, check that internal logical 4286a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar // to physical dimension calculation is correct after relayout. 4296a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar add_testcase({1, 2, 3, 4}, 0, 100, 100); 4306a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 4311654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // Zero-sized tensor. 4321654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar add_testcase({1, 0, 100, 42}, 0, 100, 100); 4331654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar 4346a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar return params; 4356a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar} 4366a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 4376a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarINSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes, 4386a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar ::testing::ValuesIn(BuildBatchNormTestParams())); 4396a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin Lebar 4406a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { 4411464b9930de871fd11870941963253670f737c23A. Unique TensorFlower float epsilon = 0.001; 4421464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 4431464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const std::vector<int64>& bounds = GetParam().bounds; 4441464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]); 4451464b9930de871fd11870941963253670f737c23A. Unique TensorFlower input_array.FillRandom(GetParam().random_value_var, 4461464b9930de871fd11870941963253670f737c23A. Unique TensorFlower GetParam().random_value_mean); 4471464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4481464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int64 feature_index = GetParam().feature_index; 4491464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int64 num_elements_per_feature = 4501464b9930de871fd11870941963253670f737c23A. Unique TensorFlower Product(bounds) / bounds[feature_index]; 4511464b9930de871fd11870941963253670f737c23A. Unique TensorFlower const int64 feature_bound = bounds[feature_index]; 4521464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> offset(feature_bound, 1); 4531464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> scale(feature_bound, 2); 4541464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4551464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto input_squared = 4561464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); 4571464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<int64> reduce_dims; 458030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) { 4591464b9930de871fd11870941963253670f737c23A. Unique TensorFlower if (i != feature_index) { 4601464b9930de871fd11870941963253670f737c23A. Unique TensorFlower reduce_dims.push_back(i); 4611464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 4621464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 4631464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4641464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto sum = 4651464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, 4661464b9930de871fd11870941963253670f737c23A. Unique TensorFlower [](float a, float b) { return a + b; }); 4671464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4681464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto sum_squared = 4691464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, 4701464b9930de871fd11870941963253670f737c23A. Unique TensorFlower [](float a, float b) { return a + b; }); 4711464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4721464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> mean(feature_bound); 4731464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4741464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 4751464b9930de871fd11870941963253670f737c23A. Unique TensorFlower mean[i] = sum[i] / num_elements_per_feature; 4761464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 4771464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4781464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> mean_square(feature_bound); 4791464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 4801464b9930de871fd11870941963253670f737c23A. Unique TensorFlower mean_square[i] = mean[i] * mean[i]; 4811464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 4821464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4831464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> square_mean(feature_bound); 4841464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 4851464b9930de871fd11870941963253670f737c23A. Unique TensorFlower square_mean[i] = sum_squared[i] / num_elements_per_feature; 4861464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 4871464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 4881464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::vector<float> var(feature_bound); 4891464b9930de871fd11870941963253670f737c23A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 4901464b9930de871fd11870941963253670f737c23A. Unique TensorFlower var[i] = square_mean[i] - mean_square[i]; 4911464b9930de871fd11870941963253670f737c23A. Unique TensorFlower } 4921464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 493f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower Array4D<float> mean4D = 4941464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); 495f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index); 496f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); 497f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto offset4D = 4981464b9930de871fd11870941963253670f737c23A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index); 4991464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 500f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, 501f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower scale4D, offset4D, epsilon); 5021464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 5031464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized); 5041464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 5051464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset_literal = Literal::CreateR1<float>(offset); 5061464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale_literal = Literal::CreateR1<float>(scale); 5071464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto input_literal = Literal::CreateR4FromArray4D<float>(input_array); 5081464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 5091464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto input_activations = 5101464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.Parameter(0, input_literal->shape(), "input"); 5111464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto scale_activations = 5121464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.Parameter(1, scale_literal->shape(), "offset"); 5131464b9930de871fd11870941963253670f737c23A. Unique TensorFlower auto offset_activations = 5141464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.Parameter(2, offset_literal->shape(), "scale"); 5151464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 5167d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan auto expected = Literal::MakeTuple({expected_normalized.get(), 5177d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan Literal::CreateR1<float>(mean).get(), 5187d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan Literal::CreateR1<float>(var).get()}); 5191464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 5201464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::unique_ptr<GlobalData> input_data = 5211464b9930de871fd11870941963253670f737c23A. Unique TensorFlower client_->TransferToServer(*input_literal).ConsumeValueOrDie(); 5221464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::unique_ptr<GlobalData> scale_data = 5231464b9930de871fd11870941963253670f737c23A. Unique TensorFlower client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); 5241464b9930de871fd11870941963253670f737c23A. Unique TensorFlower std::unique_ptr<GlobalData> offset_data = 5251464b9930de871fd11870941963253670f737c23A. Unique TensorFlower client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); 5261464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 5271464b9930de871fd11870941963253670f737c23A. Unique TensorFlower builder.BatchNormTraining(input_activations, scale_activations, 5281464b9930de871fd11870941963253670f737c23A. Unique TensorFlower offset_activations, epsilon, feature_index); 5291464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 5301654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // Run all HLO passes during this test. In particular, ClientLibraryTestBase 5311654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // disables constant folding, but we want it enabled for our zero-sized tensor 5321654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // testcase. 5331654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); 5341464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ComputeAndCompareTuple( 5357d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan &builder, *expected, 5361464b9930de871fd11870941963253670f737c23A. Unique TensorFlower {input_data.get(), scale_data.get(), offset_data.get()}, 5371464b9930de871fd11870941963253670f737c23A. Unique TensorFlower ErrorSpec(0.01, 1)); 5381464b9930de871fd11870941963253670f737c23A. Unique TensorFlower} 5391464b9930de871fd11870941963253670f737c23A. Unique TensorFlower 5406a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { 5417359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower float epsilon = 0.001; 5427359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 5437359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower const std::vector<int64>& bounds = GetParam().bounds; 5447359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]); 5457359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower input_array.FillRandom(GetParam().random_value_var, 5467359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower GetParam().random_value_mean); 5477359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5487359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower const int64 feature_index = GetParam().feature_index; 5497359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower const int64 num_elements_per_feature = 5507359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower Product(bounds) / bounds[feature_index]; 5517359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower const int64 feature_bound = bounds[feature_index]; 5527359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> offset(feature_bound, 1); 5537359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> scale(feature_bound, 2); 5547359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5557359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto input_squared = 5567359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); 5577359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<int64> reduce_dims; 5587359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) { 5597359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower if (i != feature_index) { 5607359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower reduce_dims.push_back(i); 5617359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 5627359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 5637359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5647359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto sum = 5657359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, 5667359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower [](float a, float b) { return a + b; }); 5677359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5687359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto sum_squared = 5697359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, 5707359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower [](float a, float b) { return a + b; }); 5717359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5727359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> mean(feature_bound); 5737359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5747359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 5757359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower mean[i] = sum[i] / num_elements_per_feature; 5767359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 5777359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5787359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> mean_square(feature_bound); 5797359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 5807359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower mean_square[i] = mean[i] * mean[i]; 5817359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 5827359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5837359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> square_mean(feature_bound); 5847359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 5857359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower square_mean[i] = sum_squared[i] / num_elements_per_feature; 5867359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 5877359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5887359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::vector<float> var(feature_bound); 5897359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 5907359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower var[i] = square_mean[i] - mean_square[i]; 5917359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower } 5927359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 5937359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower Array4D<float> mean4D = 5947359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); 5957359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index); 5967359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); 5977359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto offset4D = 5987359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index); 5997359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 6007359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, 6017359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower scale4D, offset4D, epsilon); 6027359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 6037359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto offset_literal = Literal::CreateR1<float>(offset); 6047359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto scale_literal = Literal::CreateR1<float>(scale); 6057359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto mean_literal = Literal::CreateR1<float>(mean); 6067359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto var_literal = Literal::CreateR1<float>(var); 6077359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto input_literal = Literal::CreateR4FromArray4D<float>(input_array); 6087359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 6097359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto input_activations = 6107359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.Parameter(0, input_literal->shape(), "input"); 6117359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto scale_activations = 6127359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.Parameter(1, scale_literal->shape(), "offset"); 6137359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto offset_activations = 6147359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.Parameter(2, offset_literal->shape(), "scale"); 6157359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean"); 6167359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower auto variance_activations = 6177359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.Parameter(4, var_literal->shape(), "variance"); 6187359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 6197359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower Array4D<float> expected = normalized; 6207359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 6217359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> input_data = 6227359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*input_literal).ConsumeValueOrDie(); 6237359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> scale_data = 6247359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); 6257359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> offset_data = 6267359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); 6277359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> mean_data = 6287359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); 6297359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower std::unique_ptr<GlobalData> variance_data = 6307359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower client_->TransferToServer(*var_literal).ConsumeValueOrDie(); 6317359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 6327359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower builder.BatchNormInference(input_activations, scale_activations, 6337359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower offset_activations, mean_activations, 6347359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower variance_activations, epsilon, feature_index); 6357359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 6361654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // Run all HLO passes during this test. In particular, ClientLibraryTestBase 6371654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // disables constant folding, but we want it enabled for our zero-sized tensor 6381654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // testcase. 6391654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); 6401654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar 6417359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ComputeAndCompareR4<float>( 6427359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower &builder, expected, 6437359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(), 6447359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower variance_data.get()}, 6457359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower ErrorSpec(0.01, 1)); 6467359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower} 6477359fec792e4efec1670a12332bb524a5608b215A. Unique TensorFlower 6486a9a9ed0e1f5eded19d793b2be125d2d845cf079Justin LebarXLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { 649ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower float epsilon = 0.001; 650ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ComputationBuilder builder(client_, TestName()); 651ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const std::vector<int64>& bounds = GetParam().bounds; 652ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]); 653ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower input_array.FillRandom(GetParam().random_value_var, 654ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower GetParam().random_value_mean); 655ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 656ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]); 657ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower grad_output_array.FillRandom(GetParam().random_value_var, 658ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower GetParam().random_value_mean); 659ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 660ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const int64 feature_index = GetParam().feature_index; 661ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const int64 num_elements_per_feature = 662ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Product(bounds) / bounds[feature_index]; 663ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower const int64 feature_bound = bounds[feature_index]; 664ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> scale(feature_bound, 2); 665ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 666ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto input_squared = 667ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); 668ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<int64> reduce_dims; 669030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) { 670ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower if (i != feature_index) { 671ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower reduce_dims.push_back(i); 672ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 673ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 674ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 675ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto sum = 676ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, 677ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a + b; }); 678ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 679ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto sum_squared = 680ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, 681ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a + b; }); 682ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 683ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> mean(feature_bound); 684ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 685ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 6861654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar if (num_elements_per_feature > 0) { 6871654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar mean[i] = sum[i] / num_elements_per_feature; 6881654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar } else { 6891654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar mean[i] = 0; 6901654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar } 691ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 692ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 693ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> mean_square(feature_bound); 694ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 695ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower mean_square[i] = mean[i] * mean[i]; 696ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 697ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 698ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> square_mean(feature_bound); 699ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 7001654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar if (num_elements_per_feature > 0) { 7011654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar square_mean[i] = sum_squared[i] / num_elements_per_feature; 7021654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar } else { 7031654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar square_mean[i] = 0; 7041654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar } 705ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 706ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 707ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::vector<float> var(feature_bound); 708ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower for (int64 i = 0; i < feature_bound; ++i) { 709ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower var[i] = square_mean[i] - mean_square[i]; 710ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower } 711ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 712f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower Array4D<float> mean4D = 713ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); 714f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index); 715f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); 716ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 717ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto var_add_epsilon = *ReferenceUtil::MapArray4D( 718030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower var4D, [epsilon](float a) { return a + epsilon; }); 719030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 720030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D( 721030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); }); 722ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 723ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_output_times_var = 724ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon, 725ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a * b; }); 726ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 727ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto activation_shifted = *ReferenceUtil::MapArray4D( 728f4230546f030e6a4f8b416ae69952bf15eca2ec9A. Unique TensorFlower input_array, mean4D, [](float a, float b) { return a - b; }); 729ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 730030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto activation_shifted_times_grad_output = 731030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted, 732ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a * b; }); 733ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 734030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D( 735030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower activation_shifted_times_grad_output, rsqrt_var_add_epsilon, 736030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower [](float a, float b) { return a * b; }); 737030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 738ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_scale = ReferenceUtil::Reduce4DTo1D( 739ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower grad_scale_before_reduction, /*init=*/0.0f, reduce_dims, 740ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a + b; }); 741ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 742ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_offset = 743ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims, 744ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower [](float a, float b) { return a + b; }); 745ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 746030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D( 747030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; }); 748030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 749030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I1 = *ReferenceUtil::MapArray4D( 750030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_output_array, [&](float a) { return num_elements_per_feature * a; }); 751030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 752030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index); 753030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 754030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // I3 = sum(output_grad * (activation - mean(activation))) 755030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I3 = *ReferenceUtil::Broadcast1DTo4D( 756030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output, 757030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower /*init=*/0.0f, reduce_dims, 758030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower [](float a, float b) { return a + b; }), 759030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower bounds, feature_index); 760030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 761030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // I4 = (activation - mean(activation)) * 762030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // sum(output_grad * (activation - mean(activation))) 763030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted, 764030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower [](float a, float b) { return a * b; }); 765030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 766030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // I5 = (activation - mean(activation)) * 767030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // sum(output_grad * (activation - mean(activation))) / (variance + 768030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower // epsilon)) 769030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon, 770030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower [](float a, float b) { return a / b; }); 771030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 772030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower auto grad_activation = *ReferenceUtil::MapArray4D( 773030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower I1, I2, [](float a, float b) { return a - b; }); 774030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 775030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation = *ReferenceUtil::MapArray4D( 776030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation, I5, [](float a, float b) { return a - b; }); 777030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 778030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation = *ReferenceUtil::MapArray4D( 779030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation, scale4D, [](float a, float b) { return a * b; }); 780030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 781030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower grad_activation = *ReferenceUtil::MapArray4D( 7821654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) { 7831654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar if (num_elements_per_feature > 0) { 7841654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar return a * b / num_elements_per_feature; 7851654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar } 7861654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar return 0.f; 7871654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar }); 788030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060A. Unique TensorFlower 789ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto expected_grad_activation = 790ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR4FromArray4D<float>(grad_activation); 791ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 792ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto input_literal = Literal::CreateR4FromArray4D<float>(input_array); 793ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto scale_literal = Literal::CreateR1<float>(scale); 794ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto mean_literal = Literal::CreateR1<float>(mean); 795ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto var_literal = Literal::CreateR1<float>(var); 796ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_output_literal = 797ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower Literal::CreateR4FromArray4D<float>(grad_output_array); 798ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 799ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto input_parameter = builder.Parameter(0, input_literal->shape(), "input"); 800ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale"); 801ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean"); 802ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance"); 803ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto grad_output_parameter = 804ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower builder.Parameter(4, grad_output_literal->shape(), "grad_output"); 805ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 806ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> input_data = 807ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*input_literal).ConsumeValueOrDie(); 808ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> scale_data = 809ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); 810ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> mean_data = 811ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); 812ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> var_data = 813ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*var_literal).ConsumeValueOrDie(); 814ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower std::unique_ptr<GlobalData> grad_output_data = 815ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); 816ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 817ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto t = builder.BatchNormGrad(input_parameter, scale_parameter, 818ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower mean_parameter, var_parameter, 819ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower grad_output_parameter, epsilon, feature_index); 820ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 821ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower auto expected = 8227d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan Literal::MakeTuple({expected_grad_activation.get(), 8237d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan Literal::CreateR1<float>(grad_scale).get(), 8247d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan Literal::CreateR1<float>(grad_offset).get()}); 825ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 8261654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // Run all HLO passes during this test. In particular, ClientLibraryTestBase 8271654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // disables constant folding, but we want it enabled for our zero-sized tensor 8281654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar // testcase. 8291654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); 8301654f7a5032776ad34aeed2c94c2bd77a72d8cafJustin Lebar 8317d64e124103c8334b7d8b127cd2eff786959d185Mark Heffernan ComputeAndCompareTuple(&builder, *expected, 832ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower {input_data.get(), scale_data.get(), mean_data.get(), 833ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower var_data.get(), grad_output_data.get()}, 834ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower ErrorSpec(0.01, 1)); 835ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower} 836ac209ebc8fc8780bb3121a33740e10a34352996fA. Unique TensorFlower 8371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace 8381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 839