1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/user_computation.h" 17 18#include "tensorflow/compiler/xla/literal_util.h" 19#include "tensorflow/compiler/xla/service/hlo_computation.h" 20#include "tensorflow/compiler/xla/service/hlo_matchers.h" 21#include "tensorflow/compiler/xla/shape_util.h" 22#include "tensorflow/compiler/xla/status_macros.h" 23#include "tensorflow/compiler/xla/test.h" 24#include "tensorflow/compiler/xla/test_helpers.h" 25#include "tensorflow/compiler/xla/xla_data.pb.h" 26#include "tensorflow/core/lib/core/status_test_util.h" 27 28namespace op = xla::testing::opcode_matchers; 29 30namespace xla { 31namespace { 32 33using UserComputationTest = ::testing::Test; 34 35TEST_F(UserComputationTest, SimpleComputation) { 36 const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); 37 const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2}); 38 39 // Build a simple three operation computatation: 40 // 41 // %constant = Constant({123, 42}) 42 // %param = Param(0) 43 // %outfeed = Outfeed(%constant) 44 // 45 // Build the computation at two different versions and check invariants. 46 ComputationHandle handle; 47 handle.set_handle(123); 48 UserComputation computation("TheComputation", handle); 49 50 ConstantRequest constant_request; 51 *constant_request.mutable_literal() = 52 Literal::CreateR1<float>({123.0f, 42.0f})->ToProto(); 53 TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle, 54 computation.AddConstantInstruction(constant_request)); 55 56 ParameterRequest param_request; 57 *param_request.mutable_shape() = kScalarShape; 58 param_request.set_parameter(0); 59 param_request.set_name("param0"); 60 TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle, 61 computation.AddParameterInstruction(param_request)); 62 OpMetadata metadata; 63 metadata.set_op_name("meta"); 64 TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); 65 66 OutfeedRequest outfeed_request; 67 *outfeed_request.mutable_operand() = constant_handle; 68 *outfeed_request.mutable_shape() = kVectorShape; 69 outfeed_request.set_outfeed_config("abc"); 70 TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, 71 computation.AddOutfeedInstruction(outfeed_request)); 72 73 auto hlo_resolver = [](const VersionedComputationHandle& handle) { 74 return nullptr; 75 }; 76 { 77 // Test the computation at the latest version. In this case, the most 78 // recently added operation is an outfeed. However, the outfeed is not the 79 // root because outfeeds cannot be the root of a computation. 80 VersionedComputationHandle latest_version = 81 computation.GetVersionedHandle(); 82 83 // Program shape should have a single scalar parameter and scalar 84 // result. The outfeed instruction should not affect the program shape. 85 TF_ASSERT_OK_AND_ASSIGN( 86 std::shared_ptr<const ProgramShape> program_shape, 87 computation.ComputeProgramShape(latest_version.version)); 88 ASSERT_EQ(1, program_shape->parameters_size()); 89 EXPECT_TRUE( 90 ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); 91 EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); 92 93 // Build the HLO computation. 94 TF_ASSERT_OK_AND_ASSIGN( 95 std::unique_ptr<HloComputation> hlo_computation, 96 computation.BuildHloComputation(latest_version.version, hlo_resolver, 97 DebugOptions())); 98 // There should be one HloInstruction per UserComputation operation. 99 EXPECT_EQ(3, hlo_computation->instruction_count()); 100 // The root of the instruction should be the parameter instruction (not the 101 // outfeed). 102 EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); 103 } 104 105 { 106 // Test the computation at the version right after the parameter instruction 107 // is added. 108 VersionedComputationHandle version_at_param = 109 computation.GetVersionedHandleAtOperation(param_handle); 110 111 // Program shape should have a single scalar parameter, and scalar result. 112 TF_ASSERT_OK_AND_ASSIGN( 113 std::shared_ptr<const ProgramShape> program_shape, 114 computation.ComputeProgramShape(version_at_param.version)); 115 ASSERT_EQ(1, program_shape->parameters_size()); 116 EXPECT_TRUE( 117 ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); 118 EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); 119 120 // There should be two instructions, one for the constant and one for the 121 // parameter. The outfeed instruction should not be included. 122 TF_ASSERT_OK_AND_ASSIGN( 123 std::unique_ptr<HloComputation> hlo_computation, 124 computation.BuildHloComputation(version_at_param.version, hlo_resolver, 125 DebugOptions())); 126 EXPECT_EQ(2, hlo_computation->instruction_count()); 127 EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); 128 } 129 { 130 // Test the computation at the latest version, but lowered with 131 // include_unreachable_instructions set to false. 132 VersionedComputationHandle latest_version = 133 computation.GetVersionedHandle(); 134 135 // Build the HLO computation. 136 TF_ASSERT_OK_AND_ASSIGN( 137 std::unique_ptr<HloComputation> hlo_computation, 138 computation.BuildHloComputation( 139 latest_version.version, hlo_resolver, DebugOptions(), 140 /*include_unreachable_instructions=*/false)); 141 // There is only one reachable instruction, the parameter. 142 EXPECT_EQ(1, hlo_computation->instruction_count()); 143 // The root of the instruction should be the parameter instruction (not the 144 // outfeed). 145 EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); 146 EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(), 147 "meta"); 148 } 149} 150 151TEST_F(UserComputationTest, EliminateScalarBroadcast) { 152 auto debug_options = DebugOptions(); 153 debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); 154 155 // Build a binary computation with scalar broadcast. 156 // 157 // %a = Constant({123, 42}) 158 // %b = Constant(1) 159 // %add = Add(%a, %b) 160 ComputationHandle handle; 161 handle.set_handle(123); 162 UserComputation computation("TheComputation", handle); 163 164 ConstantRequest a_request; 165 *a_request.mutable_literal() = 166 Literal::CreateR1<float>({123.0f, 42.0f})->ToProto(); 167 TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, 168 computation.AddConstantInstruction(a_request)); 169 170 ConstantRequest b_request; 171 *b_request.mutable_literal() = Literal::CreateR0<float>(1.0f)->ToProto(); 172 TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, 173 computation.AddConstantInstruction(b_request)); 174 175 BinaryOpRequest add; 176 add.set_binop(BINOP_ADD); 177 *add.mutable_lhs() = a_handle; 178 *add.mutable_rhs() = b_handle; 179 TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); 180 181 auto hlo_resolver = [](const VersionedComputationHandle& handle) { 182 return nullptr; 183 }; 184 VersionedComputationHandle latest_version = computation.GetVersionedHandle(); 185 186 // Build the HLO computation. 187 TF_ASSERT_OK_AND_ASSIGN( 188 std::unique_ptr<HloComputation> hlo_computation, 189 computation.BuildHloComputation(latest_version.version, hlo_resolver, 190 debug_options)); 191 // The binary operation has implicit scalar broadcast, should be converted 192 // to an explicit broadcast intruction and a binary instruction. 193 EXPECT_EQ(4, hlo_computation->instruction_count()); 194 EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); 195 LOG(INFO) << hlo_computation->root_instruction()->ToString(); 196 const auto& operands = hlo_computation->root_instruction()->operands(); 197 ASSERT_EQ(2, operands.size()); 198 EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast || 199 operands[1]->opcode() == HloOpcode::kBroadcast); 200} 201 202TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { 203 auto debug_options = DebugOptions(); 204 debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); 205 206 // Build a binary computation with degenerate broadcast. 207 // 208 // %a = Param({1, 2, 3}); 209 // %b = Param({1, 2, 1}); 210 // %add = Add(%a, %b, {}); 211 ComputationHandle handle; 212 handle.set_handle(123); 213 UserComputation computation("TheComputation", handle); 214 215 ParameterRequest a_request; 216 *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); 217 a_request.set_name("a"); 218 a_request.set_parameter(0); 219 TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, 220 computation.AddParameterInstruction(a_request)); 221 222 ParameterRequest b_request; 223 *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); 224 b_request.set_name("b"); 225 b_request.set_parameter(1); 226 TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, 227 computation.AddParameterInstruction(b_request)); 228 229 const int64 kDevice = 7; 230 OpSharding sharding; 231 sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); 232 sharding.add_tile_assignment_dimensions(1); 233 sharding.add_tile_assignment_devices(kDevice); 234 235 TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding)); 236 237 BinaryOpRequest add; 238 add.set_binop(BINOP_ADD); 239 *add.mutable_lhs() = a_handle; 240 *add.mutable_rhs() = b_handle; 241 TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); 242 243 auto hlo_resolver = [](const VersionedComputationHandle& handle) { 244 return nullptr; 245 }; 246 VersionedComputationHandle latest_version = computation.GetVersionedHandle(); 247 248 // Build the HLO computation. 249 TF_ASSERT_OK_AND_ASSIGN( 250 std::unique_ptr<HloComputation> hlo_computation, 251 computation.BuildHloComputation(latest_version.version, hlo_resolver, 252 debug_options)); 253 254 // b a 255 // | | 256 // reshape | 257 // | | 258 // broadcast | 259 // \ / 260 // add 261 EXPECT_EQ(5, hlo_computation->instruction_count()); 262 ASSERT_THAT( 263 hlo_computation->root_instruction(), 264 op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter())))); 265 266 const HloInstruction* broadcast = 267 hlo_computation->root_instruction()->operand(1); 268 EXPECT_TRUE(broadcast->has_sharding()); 269 270 const HloInstruction* reshape = broadcast->operand(0); 271 EXPECT_TRUE(reshape->has_sharding()); 272} 273 274TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { 275 auto debug_options = DebugOptions(); 276 debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); 277 278 // Build a binary computation with in-dim broadcast and degenerate broadcast. 279 // 280 // %a = Param({2, 3}); 281 // %b = Param({2, 1, 4}); 282 // %add = Add(%a, %b, {0, 1}); 283 ComputationHandle handle; 284 handle.set_handle(123); 285 UserComputation computation("TheComputation", handle); 286 287 ParameterRequest a_request; 288 *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3}); 289 a_request.set_name("a"); 290 a_request.set_parameter(0); 291 TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, 292 computation.AddParameterInstruction(a_request)); 293 294 ParameterRequest b_request; 295 *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4}); 296 b_request.set_name("b"); 297 b_request.set_parameter(1); 298 TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, 299 computation.AddParameterInstruction(b_request)); 300 301 BinaryOpRequest add; 302 add.set_binop(BINOP_ADD); 303 *add.mutable_lhs() = a_handle; 304 *add.mutable_rhs() = b_handle; 305 add.add_broadcast_dimensions(0); 306 add.add_broadcast_dimensions(1); 307 TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); 308 309 auto hlo_resolver = [](const VersionedComputationHandle& handle) { 310 return nullptr; 311 }; 312 VersionedComputationHandle latest_version = computation.GetVersionedHandle(); 313 314 // Build the HLO computation. 315 TF_ASSERT_OK_AND_ASSIGN( 316 std::unique_ptr<HloComputation> hlo_computation, 317 computation.BuildHloComputation(latest_version.version, hlo_resolver, 318 debug_options)); 319 320 // The binary operation has in-dim broadcast and degenerate broadcast, should 321 // first do the in-dim broadcast then convert the degnerate broadcast into a 322 // reshape and a broadcast. 323 // 324 // b a 325 // | | 326 // broadcast reshape 327 // | | 328 // | broadcast 329 // \ / 330 // add 331 EXPECT_EQ(6, hlo_computation->instruction_count()); 332 EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); 333 const auto& operands = hlo_computation->root_instruction()->operands(); 334 ASSERT_EQ(2, operands.size()); 335 EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast && 336 operands[1]->opcode() == HloOpcode::kBroadcast); 337} 338 339} // namespace 340} // namespace xla 341