/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/user_computation.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

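// Shorthand for the HLO opcode matchers (op::Parameter, op::Add, ...) used in
// the EXPECT_THAT assertions below.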
namespace op = xla::testing::opcode_matchers;

namespace xla {
namespace {

using UserComputationTest = ::testing::Test;

TEST_F(UserComputationTest, SimpleComputation) {
  const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
  const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2});

  // Build a simple three-operation computation:
  //
  //   %constant = Constant({123, 42})
  //   %param = Param(0)
  //   %outfeed = Outfeed(%constant)
  //
  // Build the computation at two different versions and check invariants.
  ComputationHandle handle;
  handle.set_handle(123);
  UserComputation computation("TheComputation", handle);

  ConstantRequest constant_request;
  *constant_request.mutable_literal() =
      Literal::CreateR1<float>({123.0f, 42.0f})->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle,
                          computation.AddConstantInstruction(constant_request));

  ParameterRequest param_request;
  *param_request.mutable_shape() = kScalarShape;
  param_request.set_parameter(0);
  param_request.set_name("param0");
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle,
                          computation.AddParameterInstruction(param_request));
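  // Attach OpMetadata to the parameter instruction; the last subtest below
  // verifies that the metadata survives lowering to HLO.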
  OpMetadata metadata;
  metadata.set_op_name("meta");
  TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata));

  OutfeedRequest outfeed_request;
  *outfeed_request.mutable_operand() = constant_handle;
  *outfeed_request.mutable_shape() = kVectorShape;
  outfeed_request.set_outfeed_config("abc");
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle,
                          computation.AddOutfeedInstruction(outfeed_request));

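  // None of the operations in this computation reference another computation,
  // so the resolver used during lowering can always return nullptr.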
  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
    return nullptr;
  };
  {
    // Test the computation at the latest version. In this case, the most
    // recently added operation is an outfeed. However, the outfeed is not the
    // root because outfeeds cannot be the root of a computation.
    VersionedComputationHandle latest_version =
        computation.GetVersionedHandle();

    // Program shape should have a single scalar parameter and scalar
    // result. The outfeed instruction should not affect the program shape.
    TF_ASSERT_OK_AND_ASSIGN(
        std::shared_ptr<const ProgramShape> program_shape,
        computation.ComputeProgramShape(latest_version.version));
    ASSERT_EQ(1, program_shape->parameters_size());
    EXPECT_TRUE(
        ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0)));
    EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result()));

    // Build the HLO computation.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<HloComputation> hlo_computation,
        computation.BuildHloComputation(latest_version.version, hlo_resolver,
                                        DebugOptions()));
    // There should be one HloInstruction per UserComputation operation.
    EXPECT_EQ(3, hlo_computation->instruction_count());
    // The root of the computation should be the parameter instruction (not
    // the outfeed).
    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
  }

  {
    // Test the computation at the version right after the parameter
    // instruction is added.
    VersionedComputationHandle version_at_param =
        computation.GetVersionedHandleAtOperation(param_handle);

    // Program shape should have a single scalar parameter and a scalar
    // result.
    TF_ASSERT_OK_AND_ASSIGN(
        std::shared_ptr<const ProgramShape> program_shape,
        computation.ComputeProgramShape(version_at_param.version));
    ASSERT_EQ(1, program_shape->parameters_size());
    EXPECT_TRUE(
        ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0)));
    EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result()));

    // There should be two instructions, one for the constant and one for the
    // parameter. The outfeed instruction should not be included.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<HloComputation> hlo_computation,
        computation.BuildHloComputation(version_at_param.version, hlo_resolver,
                                        DebugOptions()));
    EXPECT_EQ(2, hlo_computation->instruction_count());
    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
  }
  {
    // Test the computation at the latest version, but lowered with
    // include_unreachable_instructions set to false.
    VersionedComputationHandle latest_version =
        computation.GetVersionedHandle();

    // Build the HLO computation.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<HloComputation> hlo_computation,
        computation.BuildHloComputation(
            latest_version.version, hlo_resolver, DebugOptions(),
            /*include_unreachable_instructions=*/false));
    // There is only one reachable instruction, the parameter.
    EXPECT_EQ(1, hlo_computation->instruction_count());
    // The root of the computation should be the parameter instruction (not
    // the outfeed).
    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
    EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(),
              "meta");
  }
}

TEST_F(UserComputationTest, EliminateScalarBroadcast) {
  auto debug_options = DebugOptions();
  debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);

  // Build a binary computation with scalar broadcast.
  //
  //  %a = Constant({123, 42})
  //  %b = Constant(1)
  //  %add = Add(%a, %b)
  ComputationHandle handle;
  handle.set_handle(123);
  UserComputation computation("TheComputation", handle);

  ConstantRequest a_request;
  *a_request.mutable_literal() =
      Literal::CreateR1<float>({123.0f, 42.0f})->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
                          computation.AddConstantInstruction(a_request));

  ConstantRequest b_request;
  *b_request.mutable_literal() = Literal::CreateR0<float>(1.0f)->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
                          computation.AddConstantInstruction(b_request));

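  // Add the rank-1 constant %a to the scalar constant %b. No broadcast
  // dimensions are given, so the scalar operand is broadcast implicitly.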
  BinaryOpRequest add;
  add.set_binop(BINOP_ADD);
  *add.mutable_lhs() = a_handle;
  *add.mutable_rhs() = b_handle;
  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());

  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
    return nullptr;
  };
  VersionedComputationHandle latest_version = computation.GetVersionedHandle();

  // Build the HLO computation.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloComputation> hlo_computation,
      computation.BuildHloComputation(latest_version.version, hlo_resolver,
                                      debug_options));
  // The binary operation has an implicit scalar broadcast, which should be
  // converted to an explicit broadcast instruction and a binary instruction.
  EXPECT_EQ(4, hlo_computation->instruction_count());
  EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
  LOG(INFO) << hlo_computation->root_instruction()->ToString();
  const auto& operands = hlo_computation->root_instruction()->operands();
  ASSERT_EQ(2, operands.size());
  EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast ||
              operands[1]->opcode() == HloOpcode::kBroadcast);
}

TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) {
  auto debug_options = DebugOptions();
  debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);

  // Build a binary computation with degenerate broadcast.
  //
  //  %a = Param({1, 2, 3});
  //  %b = Param({1, 2, 1});
  //  %add = Add(%a, %b, {});
  ComputationHandle handle;
  handle.set_handle(123);
  UserComputation computation("TheComputation", handle);

  ParameterRequest a_request;
  *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3});
  a_request.set_name("a");
  a_request.set_parameter(0);
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
                          computation.AddParameterInstruction(a_request));

  ParameterRequest b_request;
  *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1});
  b_request.set_name("b");
  b_request.set_parameter(1);
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
                          computation.AddParameterInstruction(b_request));

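  // Pin %b to a single device via maximal sharding. The reshape and broadcast
  // instructions created for %b during lowering should inherit this sharding.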
  const int64 kDevice = 7;
  OpSharding sharding;
  sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
  sharding.add_tile_assignment_dimensions(1);
  sharding.add_tile_assignment_devices(kDevice);

  TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding));

  BinaryOpRequest add;
  add.set_binop(BINOP_ADD);
  *add.mutable_lhs() = a_handle;
  *add.mutable_rhs() = b_handle;
  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());

  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
    return nullptr;
  };
  VersionedComputationHandle latest_version = computation.GetVersionedHandle();

  // Build the HLO computation.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloComputation> hlo_computation,
      computation.BuildHloComputation(latest_version.version, hlo_resolver,
                                      debug_options));

  // Dimension 2 of %b is degenerate, so %b is reshaped to drop it and then
  // explicitly broadcast back up to the output shape:
  //
  //    b         a
  //    |         |
  // reshape      |
  //    |         |
  // broadcast    |
  //     \       /
  //        add
  EXPECT_EQ(5, hlo_computation->instruction_count());
  ASSERT_THAT(
      hlo_computation->root_instruction(),
      op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter()))));

  const HloInstruction* broadcast =
      hlo_computation->root_instruction()->operand(1);
  EXPECT_TRUE(broadcast->has_sharding());

  const HloInstruction* reshape = broadcast->operand(0);
  EXPECT_TRUE(reshape->has_sharding());
}

TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
  auto debug_options = DebugOptions();
  debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);

  // Build a binary computation with in-dim broadcast and degenerate broadcast.
  //
  //  %a = Param({2, 3});
  //  %b = Param({2, 1, 4});
  //  %add = Add(%a, %b, {0, 1});
  ComputationHandle handle;
  handle.set_handle(123);
  UserComputation computation("TheComputation", handle);

  ParameterRequest a_request;
  *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3});
  a_request.set_name("a");
  a_request.set_parameter(0);
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
                          computation.AddParameterInstruction(a_request));

  ParameterRequest b_request;
  *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4});
  b_request.set_name("b");
  b_request.set_parameter(1);
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
                          computation.AddParameterInstruction(b_request));

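  // Map %a's dimensions onto dimensions 0 and 1 of %b's shape (an in-dim
  // broadcast). %a's dimension of size 3 lines up with %b's degenerate
  // dimension 1, so the result shape is {2, 3, 4}.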
  BinaryOpRequest add;
  add.set_binop(BINOP_ADD);
  *add.mutable_lhs() = a_handle;
  *add.mutable_rhs() = b_handle;
  add.add_broadcast_dimensions(0);
  add.add_broadcast_dimensions(1);
  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());

  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
    return nullptr;
  };
  VersionedComputationHandle latest_version = computation.GetVersionedHandle();

  // Build the HLO computation.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloComputation> hlo_computation,
      computation.BuildHloComputation(latest_version.version, hlo_resolver,
                                      debug_options));

  // The binary operation has an in-dim broadcast and a degenerate broadcast.
  // The lowering should first perform the in-dim broadcast of %a, then
  // convert the degenerate broadcast of %b into a reshape and a broadcast.
  //
  //    a         b
  //    |         |
  // broadcast reshape
  //    |         |
  //    |     broadcast
  //     \        /
  //        add
  EXPECT_EQ(6, hlo_computation->instruction_count());
  EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
  const auto& operands = hlo_computation->root_instruction()->operands();
  ASSERT_EQ(2, operands.size());
  EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast &&
              operands[1]->opcode() == HloOpcode::kBroadcast);
}

}  // namespace
}  // namespace xla