1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <cmath> 17#include <memory> 18#include <vector> 19 20#include "tensorflow/compiler/xla/array2d.h" 21#include "tensorflow/compiler/xla/array4d.h" 22#include "tensorflow/compiler/xla/client/computation.h" 23#include "tensorflow/compiler/xla/client/computation_builder.h" 24#include "tensorflow/compiler/xla/client/lib/arithmetic.h" 25#include "tensorflow/compiler/xla/client/local_client.h" 26#include "tensorflow/compiler/xla/literal_util.h" 27#include "tensorflow/compiler/xla/reference_util.h" 28#include "tensorflow/compiler/xla/service/hlo_computation.h" 29#include "tensorflow/compiler/xla/service/hlo_instruction.h" 30#include "tensorflow/compiler/xla/service/hlo_module.h" 31#include "tensorflow/compiler/xla/shape_util.h" 32#include "tensorflow/compiler/xla/statusor.h" 33#include "tensorflow/compiler/xla/test.h" 34#include "tensorflow/compiler/xla/test_helpers.h" 35#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 36#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 37#include "tensorflow/compiler/xla/tests/literal_test_util.h" 38#include "tensorflow/compiler/xla/tests/test_macros.h" 39#include "tensorflow/compiler/xla/tests/test_utils.h" 40#include "tensorflow/compiler/xla/util.h" 41#include "tensorflow/compiler/xla/xla_data.pb.h" 42#include "tensorflow/core/platform/logging.h" 43#include "tensorflow/core/platform/test.h" 44#include "tensorflow/core/platform/types.h" 45 46namespace xla { 47namespace { 48 49class Bfloat16Test : public ClientLibraryTestBase { 50 protected: 51 const ErrorSpec error_spec_{0.001, 0.001}; 52}; 53 54XLA_TEST_F(Bfloat16Test, ScalarOperation) { 55 ComputationBuilder builder(client_, TestName()); 56 auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f)); 57 auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f)); 58 builder.Add(x, y); 59 60 ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {}, 61 error_spec_); 62} 63 64XLA_TEST_F(Bfloat16Test, LogOperation) { 65 ComputationBuilder builder(client_, TestName()); 66 auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(4.0f)); 67 builder.Log(x); 68 69 ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(1.387f), {}, 70 error_spec_); 71} 72 73XLA_TEST_F(Bfloat16Test, NegateScalarF16) { 74 ComputationBuilder builder(client_, TestName()); 75 builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f))); 76 77 ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {}, 78 error_spec_); 79} 80 81XLA_TEST_F(Bfloat16Test, BatchNormTraining) { 82 const int kFeatureIndex = 2; 83 ComputationBuilder builder(client_, TestName()); 84 85 auto operand = builder.ConstantR4FromArray4D<bfloat16>( 86 {{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}}, 87 {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}}, 88 {{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}}, 89 {{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}}); 90 91 auto scale = builder.ConstantR1<bfloat16>( 92 {static_cast<bfloat16>(2.0f), static_cast<bfloat16>(3.0f)}); 93 94 auto offset = builder.ConstantR1<bfloat16>( 95 {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)}); 96 97 auto tuple = builder.BatchNormTraining(operand, scale, offset, 98 /*epsilon=*/0.001, kFeatureIndex); 99 100 auto expected = Literal::MakeTuple( 101 {Literal::CreateR4<bfloat16>( 102 {{{{static_cast<bfloat16>(-1.6875f)}, 103 {static_cast<bfloat16>(-2.04f)}}, 104 {{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}}, 105 {{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}}, 106 {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}) 107 .get(), 108 Literal::CreateR1<bfloat16>( 109 {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}) 110 .get(), 111 Literal::CreateR1<bfloat16>( 112 {static_cast<bfloat16>(5), static_cast<bfloat16>(5)}) 113 .get()}); 114 115 ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); 116} 117 118XLA_TEST_F(Bfloat16Test, BatchNormGrad) { 119 const int kFeatureIndex = 2; 120 ComputationBuilder builder(client_, TestName()); 121 122 auto operand = builder.ConstantR4FromArray4D<bfloat16>( 123 Array4D<bfloat16>(2, 2, 2, 1, static_cast<bfloat16>(0.0f))); 124 125 auto scale = builder.ConstantR1<bfloat16>( 126 {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)}); 127 128 auto mean = builder.ConstantR1<bfloat16>( 129 {static_cast<bfloat16>(0.0f), static_cast<bfloat16>(0.0f)}); 130 131 auto var = builder.ConstantR1<bfloat16>( 132 {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)}); 133 134 auto grad_output = builder.ConstantR4FromArray4D<bfloat16>( 135 {{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}}, 136 {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}}, 137 {{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}}, 138 {{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}}); 139 140 builder.BatchNormGrad(operand, scale, mean, var, grad_output, 141 /*epsilon=*/0.0, kFeatureIndex); 142 143 auto expected = Literal::MakeTuple( 144 {Literal::CreateR4<bfloat16>( 145 {{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}}, 146 {{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}}, 147 {{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}}, 148 {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}) 149 .get(), 150 Literal::CreateR1<bfloat16>( 151 {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}) 152 .get(), 153 Literal::CreateR1<bfloat16>( 154 {static_cast<bfloat16>(16), static_cast<bfloat16>(20)}) 155 .get()}); 156 157 ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); 158} 159 160} // namespace 161} // namespace xla 162