1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <memory> 17#include <utility> 18 19#include "tensorflow/compiler/xla/literal_util.h" 20#include "tensorflow/compiler/xla/ptr_util.h" 21#include "tensorflow/compiler/xla/service/hlo_computation.h" 22#include "tensorflow/compiler/xla/service/hlo_instruction.h" 23#include "tensorflow/compiler/xla/service/hlo_module.h" 24#include "tensorflow/compiler/xla/shape_util.h" 25#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 26#include "tensorflow/compiler/xla/tests/literal_test_util.h" 27#include "tensorflow/compiler/xla/tests/test_macros.h" 28#include "tensorflow/compiler/xla/xla_data.pb.h" 29#include "tensorflow/core/platform/test.h" 30 31namespace xla { 32namespace { 33 34class BroadcastTest : public HloTestBase {}; 35 36XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { 37 // Test degenerate case of broadcasting a scalar into a scalar. 38 auto builder = HloComputation::Builder(TestName()); 39 auto input = builder.AddInstruction( 40 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); 41 builder.AddInstruction(HloInstruction::CreateBroadcast( 42 ShapeUtil::MakeShape(F32, {}), input, {})); 43 44 // Create HLO module, compile, and execute. 45 auto hlo_module = CreateNewModule(); 46 hlo_module->AddEntryComputation(builder.Build()); 47 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 48 49 LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(42.0), *result, 50 error_spec_); 51} 52 53XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { 54 auto builder = HloComputation::Builder(TestName()); 55 auto input = builder.AddInstruction( 56 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); 57 builder.AddInstruction(HloInstruction::CreateBroadcast( 58 ShapeUtil::MakeShape(F32, {2, 2}), input, {})); 59 60 // Create HLO module, compile, and execute. 61 auto hlo_module = CreateNewModule(); 62 hlo_module->AddEntryComputation(builder.Build()); 63 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 64 65 LiteralTestUtil::ExpectNear( 66 *Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result, 67 error_spec_); 68} 69 70XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { 71 auto builder = HloComputation::Builder(TestName()); 72 auto input = builder.AddInstruction(HloInstruction::CreateConstant( 73 Literal::CreateR1<float>({1.0, 2.0, 3.0}))); 74 75 // Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple 76 // to enable testing of the results. 77 auto element1 = builder.AddInstruction(HloInstruction::CreateBroadcast( 78 ShapeUtil::MakeShape(F32, {3, 2}), input, {0})); 79 auto element2 = builder.AddInstruction(HloInstruction::CreateBroadcast( 80 ShapeUtil::MakeShape(F32, {2, 3}), input, {1})); 81 builder.AddInstruction(HloInstruction::CreateTuple({element1, element2})); 82 83 // Create HLO module, compile, and execute. 84 auto hlo_module = CreateNewModule(); 85 hlo_module->AddEntryComputation(builder.Build()); 86 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 87 88 LiteralTestUtil::ExpectNear( 89 *Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), 90 LiteralView::Create(*result, {0}), error_spec_); 91 92 LiteralTestUtil::ExpectNear( 93 *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), 94 LiteralView::Create(*result, {1}), error_spec_); 95} 96 97XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { 98 auto builder = HloComputation::Builder(TestName()); 99 auto input = builder.AddInstruction(HloInstruction::CreateConstant( 100 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 101 builder.AddInstruction(HloInstruction::CreateBroadcast( 102 ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); 103 104 // Create HLO module, compile, and execute. 105 auto hlo_module = CreateNewModule(); 106 hlo_module->AddEntryComputation(builder.Build()); 107 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 108 109 LiteralTestUtil::ExpectNear( 110 *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result, 111 error_spec_); 112} 113 114XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { 115 // Degenerately broadcasting a shape into a shape of the same rank reorders 116 // the dimensions, ie transpose. 117 auto builder = HloComputation::Builder(TestName()); 118 auto input = builder.AddInstruction(HloInstruction::CreateConstant( 119 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 120 builder.AddInstruction(HloInstruction::CreateBroadcast( 121 ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); 122 123 // Create HLO module, compile, and execute. 124 auto hlo_module = CreateNewModule(); 125 hlo_module->AddEntryComputation(builder.Build()); 126 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 127 128 LiteralTestUtil::ExpectNear( 129 *Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result, 130 error_spec_); 131} 132 133XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { 134 auto builder = HloComputation::Builder(TestName()); 135 auto input = builder.AddInstruction(HloInstruction::CreateConstant( 136 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 137 builder.AddInstruction(HloInstruction::CreateBroadcast( 138 ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); 139 140 // Create HLO module, compile, and execute. 141 auto hlo_module = CreateNewModule(); 142 hlo_module->AddEntryComputation(builder.Build()); 143 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 144 145 LiteralTestUtil::ExpectNear( 146 *Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, 147 {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), 148 *result, error_spec_); 149} 150 151TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { 152 auto builder = HloComputation::Builder(TestName()); 153 auto input = builder.AddInstruction( 154 HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 2.0}))); 155 156 // Broadcast vector in dimension 1. 157 builder.AddInstruction(HloInstruction::CreateBroadcast( 158 ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1})); 159 160 // Create HLO module, compile, and execute. 161 auto hlo_module = CreateNewModule(); 162 hlo_module->AddEntryComputation(builder.Build()); 163 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 164 165 Array4D<float> expected(2, 2, 3, 3); 166 Array2D<float> pz({{1, 2}, {1, 2}}); 167 expected.FillWithPZ(pz); 168 169 LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), 170 *result, error_spec_); 171} 172 173TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { 174 auto builder = HloComputation::Builder(TestName()); 175 std::vector<float> input_data(1025); 176 int64 r1_size = input_data.size(); 177 std::iota(input_data.begin(), input_data.end(), 0.0f); 178 auto input = builder.AddInstruction( 179 HloInstruction::CreateConstant(Literal::CreateR1<float>(input_data))); 180 181 // Broadcast vector in dimension 3. 182 builder.AddInstruction(HloInstruction::CreateBroadcast( 183 ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3})); 184 185 // Create HLO module, compile, and execute. 186 auto hlo_module = CreateNewModule(); 187 hlo_module->AddEntryComputation(builder.Build()); 188 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 189 190 Array4D<float> expected(3, 3, 3, 1025); 191 Array2D<float> yx(3, r1_size); 192 for (int64 y = 0; y < 3; ++y) { 193 for (int64 x = 0; x < r1_size; ++x) { 194 yx(y, x) = input_data[x]; 195 } 196 } 197 expected.FillWithYX(yx); 198 199 LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), 200 *result, error_spec_); 201} 202 203XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { 204 auto builder = HloComputation::Builder(TestName()); 205 Array4D<float> r4_array(32, 64, 7, 7); 206 r4_array.Fill(42.0); 207 std::vector<float> r1_array(64, 42.0); 208 209 auto input = builder.AddInstruction( 210 HloInstruction::CreateConstant(Literal::CreateR1<float>(r1_array))); 211 212 // Broadcast vector in dimension 1. 213 builder.AddInstruction(HloInstruction::CreateBroadcast( 214 ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1})); 215 216 // Create HLO module, compile, and execute. 217 auto hlo_module = CreateNewModule(); 218 hlo_module->AddEntryComputation(builder.Build()); 219 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 220 221 LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, 222 error_spec_); 223} 224 225TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { 226 auto builder = HloComputation::Builder(TestName()); 227 auto input = builder.AddInstruction( 228 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); 229 builder.AddInstruction(HloInstruction::CreateBroadcast( 230 ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); 231 232 // Create HLO module, compile, and execute. 233 auto hlo_module = CreateNewModule(); 234 hlo_module->AddEntryComputation(builder.Build()); 235 LOG(INFO) << hlo_module->ToString(); 236 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 237 238 Array4D<float> expected(64, 64, 3, 3); 239 expected.Fill(1.0f); 240 241 LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), 242 *result, error_spec_); 243} 244 245TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { 246 auto builder = HloComputation::Builder(TestName()); 247 Array2D<float> to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}}); 248 auto input = builder.AddInstruction(HloInstruction::CreateConstant( 249 Literal::CreateR2FromArray2D<float>(to_broadcast))); 250 251 // Broadcast vector in dimensions 2 and 3. 252 builder.AddInstruction(HloInstruction::CreateBroadcast( 253 ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3})); 254 255 // Create HLO module, compile, and execute. 256 auto hlo_module = CreateNewModule(); 257 hlo_module->AddEntryComputation(builder.Build()); 258 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 259 260 Array4D<float> expected(3, 3, 2, 2); 261 expected.FillWithYX(to_broadcast); 262 263 LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), 264 *result, error_spec_); 265} 266 267TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { 268 auto builder = HloComputation::Builder(TestName()); 269 Array3D<float> input_vals(2, 3, 4); 270 input_vals.FillRandom(1.0); 271 272 Array4D<float> expected(2, 3, 4, 5); 273 for (int i = 0; i < 2; ++i) { 274 for (int j = 0; j < 3; ++j) { 275 for (int k = 0; k < 4; ++k) { 276 for (int m = 0; m < 5; ++m) { 277 expected(i, j, k, m) = input_vals(i, j, k); 278 } 279 } 280 } 281 } 282 auto input = builder.AddInstruction(HloInstruction::CreateConstant( 283 Literal::CreateR3FromArray3D<float>(input_vals))); 284 285 // Broadcast vector in dimensions 2 and 3. 286 builder.AddInstruction(HloInstruction::CreateBroadcast( 287 ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); 288 289 // Create HLO module, compile, and execute. 290 auto hlo_module = CreateNewModule(); 291 hlo_module->AddEntryComputation(builder.Build()); 292 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 293 294 LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected), 295 *result, error_spec_); 296} 297 298} // namespace 299} // namespace xla 300