1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
17
18#include <memory>
19#include <string>
20
21#include "tensorflow/compiler/xla/service/hlo_computation.h"
22#include "tensorflow/compiler/xla/service/hlo_instruction.h"
23#include "tensorflow/compiler/xla/service/hlo_opcode.h"
24#include "tensorflow/compiler/xla/service/hlo_ordering.h"
25#include "tensorflow/compiler/xla/shape_util.h"
26#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27#include "tensorflow/compiler/xla/types.h"
28#include "tensorflow/compiler/xla/xla_data.pb.h"
29
30namespace xla {
31namespace {
32
33class MinimumMemoryForSequenceTest : public HloTestBase {};
34
35TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
36  auto module = CreateNewModule();
37  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
38  const Shape tuple_shape =
39      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
40
41  auto cond_builder = HloComputation::Builder("WhileCond");
42  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
43  HloInstruction* cond_param = cond_builder.AddInstruction(
44      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
45  HloInstruction* cond_iter = cond_builder.AddInstruction(
46      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
47  HloInstruction* cond_data = cond_builder.AddInstruction(
48      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
49  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
50  HloInstruction* cond_lt = cond_builder.AddInstruction(
51      HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
52                                   HloOpcode::kLt, cond_iter, cond_data));
53  HloComputation* cond_computation =
54      module->AddEmbeddedComputation(cond_builder.Build());
55
56  auto body_builder = HloComputation::Builder("WhileBody");
57  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
58  HloInstruction* body_param = body_builder.AddInstruction(
59      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
60  HloComputation* body_computation =
61      module->AddEmbeddedComputation(body_builder.Build());
62
63  auto builder = HloComputation::Builder(TestName());
64  // Entry params: 8 bytes (4 bytes per param), TOTAL=8
65  HloInstruction* iter = builder.AddInstruction(
66      HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
67  HloInstruction* data = builder.AddInstruction(
68      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
69  // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
70  HloInstruction* tuple =
71      builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
72  // While: 8 bytes (4 bytes per element), TOTAL=32
73  // Both cond and body use a max of 24 bytes, TOTAL=56
74  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
75      tuple_shape, cond_computation, body_computation, tuple));
76  HloComputation* entry_computation =
77      module->AddEntryComputation(builder.Build());
78
79  auto size_fn = [](const LogicalBuffer& buffer) {
80    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
81  };
82
83  SequentialHloOrdering::HloModuleSequence module_sequence;
84  module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
85                                       cond_lt};
86  module_sequence[body_computation] = {body_param};
87  module_sequence[entry_computation] = {iter, data, tuple, while_op};
88  EXPECT_EQ(56,
89            MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie());
90}
91
92}  // namespace
93}  // namespace xla
94