1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/hlo_module.h"
17
18#include "tensorflow/compiler/xla/literal_util.h"
19#include "tensorflow/compiler/xla/ptr_util.h"
20#include "tensorflow/compiler/xla/service/hlo_computation.h"
21#include "tensorflow/compiler/xla/service/hlo_instruction.h"
22#include "tensorflow/compiler/xla/shape_util.h"
23#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24#include "tensorflow/compiler/xla/xla_data.pb.h"
25
26#include "tensorflow/compiler/xla/test.h"
27#include "tensorflow/core/lib/gtl/array_slice.h"
28
29namespace xla {
30
31namespace {
32
33class HloModuleTest : public HloTestBase {
34 protected:
35  HloModuleTest() {}
36
37  // Create a computation which returns a constant.
38  std::unique_ptr<HloComputation> CreateConstantComputation() {
39    auto builder = HloComputation::Builder("Constant");
40    builder.AddInstruction(
41        HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
42    return builder.Build();
43  }
44
45  // Creates a computation which calls the given zero-parameter computations.
46  std::unique_ptr<HloComputation> CreateCallComputation(
47      tensorflow::gtl::ArraySlice<HloComputation*> computations) {
48    auto builder = HloComputation::Builder("Call");
49    for (auto computation : computations) {
50      builder.AddInstruction(
51          HloInstruction::CreateCall(r0f32_, {}, computation));
52    }
53    return builder.Build();
54  }
55
56  Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
57};
58
59TEST_F(HloModuleTest, OneComputationPostOrder) {
60  // Create a module with a single computation.
61  auto module = CreateNewModule();
62  auto computation = module->AddEntryComputation(CreateConstantComputation());
63
64  EXPECT_THAT(module->MakeComputationPostOrder(),
65              ::testing::ElementsAre(computation));
66}
67
68TEST_F(HloModuleTest, TwoComputationsPostOrder) {
69  // Create a module with two unconnected computations.
70  auto module = CreateNewModule();
71  auto computation1 = module->AddEntryComputation(CreateConstantComputation());
72  auto computation2 =
73      module->AddEmbeddedComputation(CreateConstantComputation());
74
75  EXPECT_THAT(module->MakeComputationPostOrder(),
76              ::testing::UnorderedElementsAre(computation1, computation2));
77
78  // We specified the same name for both computations, but the HloModule should
79  // have made the names unique.
80  EXPECT_EQ(computation1->name(), "Constant");
81  EXPECT_EQ(computation2->name(), "Constant.1");
82}
83
84TEST_F(HloModuleTest, CloneTest) {
85  // Create and copy a module with a diamond call graph of computations.
86  auto module = CreateNewModule();
87  auto computation1 =
88      module->AddEmbeddedComputation(CreateConstantComputation());
89  auto computation2 =
90      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
91  auto computation3 =
92      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
93  module->AddEntryComputation(
94      CreateCallComputation({computation2, computation3}));
95
96  auto post_order = module->MakeComputationPostOrder();
97  auto cloned_module = module->Clone("copy");
98  auto post_order_copied = cloned_module->MakeComputationPostOrder();
99
100  EXPECT_EQ(post_order.size(), post_order_copied.size());
101  for (auto origin = post_order.begin(), copied = post_order_copied.begin();
102       origin != post_order.end() && copied != post_order_copied.end();
103       ++origin, ++copied) {
104    EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name());
105  }
106}
107
108TEST_F(HloModuleTest, CloneHasFusion) {
109  auto module = CreateNewModule();
110
111  // Create the fused computation.
112  HloComputation* fused_computation;
113  {
114    auto b = HloComputation::Builder("Fused");
115    auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
116    b.AddInstruction(
117        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x));
118    fused_computation = module->AddEmbeddedComputation(b.Build());
119  }
120
121  // Create the entry computation.
122  {
123    auto b = HloComputation::Builder("Entry");
124    auto input = b.AddInstruction(
125        HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
126    b.AddInstruction(
127        HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
128                                     /*operands=*/{input}, fused_computation));
129    module->AddEntryComputation(b.Build());
130  }
131
132  auto post_order = module->MakeComputationPostOrder();
133  auto cloned_module = module->Clone("copy");
134  auto post_order_copied = cloned_module->MakeComputationPostOrder();
135
136  EXPECT_EQ(post_order.size(), post_order_copied.size());
137  for (auto origin = post_order.begin(), copied = post_order_copied.begin();
138       origin != post_order.end() && copied != post_order_copied.end();
139       ++origin, ++copied) {
140    if ((*origin)->name() == "Fused") {
141      // Clone of the fused computation is handled when its fusion instruction
142      // is cloned, which always use suffix ".clone".
143      EXPECT_EQ((*origin)->name() + ".clone", (*copied)->name());
144    } else {
145      EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name());
146    }
147  }
148}
149
150TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
151  // Create a module with a diamond call graph of computations.
152  auto module = CreateNewModule();
153  auto computation1 =
154      module->AddEmbeddedComputation(CreateConstantComputation());
155  auto computation2 =
156      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
157  auto computation3 =
158      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
159  auto computation4 = module->AddEntryComputation(
160      CreateCallComputation({computation2, computation3}));
161
162  auto post_order = module->MakeComputationPostOrder();
163  EXPECT_THAT(post_order,
164              ::testing::UnorderedElementsAre(computation1, computation2,
165                                              computation3, computation4));
166  EXPECT_EQ(post_order.back(), computation4);
167  EXPECT_EQ(post_order.front(), computation1);
168}
169
170TEST_F(HloModuleTest, LargeConstantToString) {
171  // Create a module with a single computation.
172  auto module = CreateNewModule();
173  auto builder = HloComputation::Builder("Constant");
174  std::vector<float> values(16, 42.0);
175  builder.AddInstruction(
176      HloInstruction::CreateConstant(Literal::CreateR1<float>(values)));
177  module->AddEntryComputation(builder.Build());
178
179  EXPECT_EQ(
180      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n  "
181      "ROOT %constant = f32[16]{0} constant({...})\n}\n\n",
182      module->ToString(HloPrintOptions().set_print_large_constants(false)));
183
184  EXPECT_EQ(
185      "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n  "
186      "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, "
187      "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n",
188      module->ToString(HloPrintOptions().set_print_large_constants(true)));
189}
190
191TEST_F(HloModuleTest, UniqueModuleId) {
192  auto module_a = CreateNewModule();
193  auto module_b = CreateNewModule();
194  EXPECT_NE(module_a->unique_id(), module_b->unique_id());
195}
196
197}  // namespace
198
199}  // namespace xla
200