1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/hlo_verifier.h"
17
18#include <memory>
19#include <utility>
20
21#include "tensorflow/compiler/xla/service/hlo_computation.h"
22#include "tensorflow/compiler/xla/service/hlo_instruction.h"
23#include "tensorflow/compiler/xla/service/hlo_opcode.h"
24#include "tensorflow/compiler/xla/shape_util.h"
25#include "tensorflow/compiler/xla/test.h"
26#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27#include "tensorflow/compiler/xla/types.h"
28#include "tensorflow/compiler/xla/xla_data.pb.h"
29#include "tensorflow/core/lib/core/status_test_util.h"
30
31namespace xla {
32namespace {
33
34using ::testing::HasSubstr;
35
36using HloVerifierTest = HloTestBase;
37
38TEST_F(HloVerifierTest, NullInstructionParent) {
39  HloComputation::Builder builder(TestName());
40  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
41  HloInstruction* param = builder.AddInstruction(
42      HloInstruction::CreateParameter(0, scalar_shape, "param"));
43  HloInstruction* negate = builder.AddInstruction(
44      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
45  auto module = CreateNewModule();
46  module->AddEntryComputation(builder.Build());
47
48  TF_ASSERT_OK(verifier().Run(module.get()).status());
49
50  negate->set_parent(nullptr);
51
52  auto status = verifier().Run(module.get()).status();
53  ASSERT_FALSE(status.ok());
54  EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer"));
55}
56
57TEST_F(HloVerifierTest, NullComputationParent) {
58  HloComputation::Builder builder(TestName());
59  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
60  HloInstruction* param = builder.AddInstruction(
61      HloInstruction::CreateParameter(0, scalar_shape, "param"));
62  builder.AddInstruction(
63      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
64  auto module = CreateNewModule();
65  HloComputation* computation = module->AddEntryComputation(builder.Build());
66
67  TF_ASSERT_OK(verifier().Run(module.get()).status());
68
69  computation->set_parent(nullptr);
70
71  auto status = verifier().Run(module.get()).status();
72  ASSERT_FALSE(status.ok());
73  EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer"));
74}
75
76TEST_F(HloVerifierTest, DifferentOperandParents) {
77  HloComputation::Builder builder(TestName());
78  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
79  HloInstruction* param = builder.AddInstruction(
80      HloInstruction::CreateParameter(0, scalar_shape, "param"));
81  HloInstruction* negate = builder.AddInstruction(
82      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
83  auto module = CreateNewModule();
84  module->AddEntryComputation(builder.Build());
85
86  HloComputation::Builder emb_builder(TestName());
87  HloInstruction* emb_param = emb_builder.AddInstruction(
88      HloInstruction::CreateParameter(0, scalar_shape, "param"));
89  module->AddEmbeddedComputation(emb_builder.Build());
90
91  TF_ASSERT_OK(verifier().Run(module.get()).status());
92  TF_ASSERT_OK(negate->ReplaceOperandWith(0, emb_param));
93
94  auto status = verifier().Run(module.get()).status();
95  ASSERT_FALSE(status.ok());
96  EXPECT_THAT(status.error_message(),
97              HasSubstr("is in a different computation"));
98}
99
100TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
101  HloComputation::Builder builder(TestName());
102  Shape s1 = ShapeUtil::MakeShape(F32, {1});
103  Shape s2 = ShapeUtil::MakeShape(F32, {2});
104
105  HloInstruction* param =
106      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "param"));
107
108  // Create an add instruction with the incorrect shape.
109  HloInstruction* add = builder.AddInstruction(
110      HloInstruction::CreateBinary(s2, HloOpcode::kAdd, param, param));
111
112  // In order to trigger the bug we're checking for, the instruction with the
113  // bad shape can't be the root of the computation.
114  builder.AddInstruction(
115      HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add));
116
117  auto module = CreateNewModule();
118  module->AddEntryComputation(builder.Build());
119
120  // Run the verifier twice.  It should fail both times, because it shouldn't
121  // carry state in its DFS visitor between runs.
122  EXPECT_FALSE(verifier().Run(module.get()).status().ok());
123  EXPECT_FALSE(verifier().Run(module.get()).status().ok());
124}
125
126}  // namespace
127}  // namespace xla
128