1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <cmath>
17#include <memory>
18#include <vector>
19
20#include "tensorflow/compiler/xla/array2d.h"
21#include "tensorflow/compiler/xla/array4d.h"
22#include "tensorflow/compiler/xla/client/computation.h"
23#include "tensorflow/compiler/xla/client/computation_builder.h"
24#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
25#include "tensorflow/compiler/xla/client/local_client.h"
26#include "tensorflow/compiler/xla/literal_util.h"
27#include "tensorflow/compiler/xla/reference_util.h"
28#include "tensorflow/compiler/xla/service/hlo_computation.h"
29#include "tensorflow/compiler/xla/service/hlo_instruction.h"
30#include "tensorflow/compiler/xla/service/hlo_module.h"
31#include "tensorflow/compiler/xla/shape_util.h"
32#include "tensorflow/compiler/xla/statusor.h"
33#include "tensorflow/compiler/xla/test.h"
34#include "tensorflow/compiler/xla/test_helpers.h"
35#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
36#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
37#include "tensorflow/compiler/xla/tests/literal_test_util.h"
38#include "tensorflow/compiler/xla/tests/test_macros.h"
39#include "tensorflow/compiler/xla/tests/test_utils.h"
40#include "tensorflow/compiler/xla/util.h"
41#include "tensorflow/compiler/xla/xla_data.pb.h"
42#include "tensorflow/core/platform/logging.h"
43#include "tensorflow/core/platform/test.h"
44#include "tensorflow/core/platform/types.h"
45
46namespace xla {
47namespace {
48
49class Bfloat16Test : public ClientLibraryTestBase {
50 protected:
51  const ErrorSpec error_spec_{0.001, 0.001};
52};
53
54XLA_TEST_F(Bfloat16Test, ScalarOperation) {
55  ComputationBuilder builder(client_, TestName());
56  auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
57  auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
58  builder.Add(x, y);
59
60  ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {},
61                                error_spec_);
62}
63
64XLA_TEST_F(Bfloat16Test, LogOperation) {
65  ComputationBuilder builder(client_, TestName());
66  auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(4.0f));
67  builder.Log(x);
68
69  ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(1.387f), {},
70                                error_spec_);
71}
72
73XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
74  ComputationBuilder builder(client_, TestName());
75  builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));
76
77  ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {},
78                                error_spec_);
79}
80
81XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
82  const int kFeatureIndex = 2;
83  ComputationBuilder builder(client_, TestName());
84
85  auto operand = builder.ConstantR4FromArray4D<bfloat16>(
86      {{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
87        {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
88       {{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
89        {{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});
90
91  auto scale = builder.ConstantR1<bfloat16>(
92      {static_cast<bfloat16>(2.0f), static_cast<bfloat16>(3.0f)});
93
94  auto offset = builder.ConstantR1<bfloat16>(
95      {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)});
96
97  auto tuple = builder.BatchNormTraining(operand, scale, offset,
98                                         /*epsilon=*/0.001, kFeatureIndex);
99
100  auto expected = Literal::MakeTuple(
101      {Literal::CreateR4<bfloat16>(
102           {{{{static_cast<bfloat16>(-1.6875f)},
103              {static_cast<bfloat16>(-2.04f)}},
104             {{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
105            {{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
106             {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
107           .get(),
108       Literal::CreateR1<bfloat16>(
109           {static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
110           .get(),
111       Literal::CreateR1<bfloat16>(
112           {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
113           .get()});
114
115  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
116}
117
118XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
119  const int kFeatureIndex = 2;
120  ComputationBuilder builder(client_, TestName());
121
122  auto operand = builder.ConstantR4FromArray4D<bfloat16>(
123      Array4D<bfloat16>(2, 2, 2, 1, static_cast<bfloat16>(0.0f)));
124
125  auto scale = builder.ConstantR1<bfloat16>(
126      {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
127
128  auto mean = builder.ConstantR1<bfloat16>(
129      {static_cast<bfloat16>(0.0f), static_cast<bfloat16>(0.0f)});
130
131  auto var = builder.ConstantR1<bfloat16>(
132      {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
133
134  auto grad_output = builder.ConstantR4FromArray4D<bfloat16>(
135      {{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
136        {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
137       {{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
138        {{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});
139
140  builder.BatchNormGrad(operand, scale, mean, var, grad_output,
141                        /*epsilon=*/0.0, kFeatureIndex);
142
143  auto expected = Literal::MakeTuple(
144      {Literal::CreateR4<bfloat16>(
145           {{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
146             {{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
147            {{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
148             {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
149           .get(),
150       Literal::CreateR1<bfloat16>(
151           {static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
152           .get(),
153       Literal::CreateR1<bfloat16>(
154           {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
155           .get()});
156
157  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
158}
159
160}  // namespace
161}  // namespace xla
162