1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/hlo_verifier.h" 17 18#include <memory> 19#include <utility> 20 21#include "tensorflow/compiler/xla/service/hlo_computation.h" 22#include "tensorflow/compiler/xla/service/hlo_instruction.h" 23#include "tensorflow/compiler/xla/service/hlo_opcode.h" 24#include "tensorflow/compiler/xla/shape_util.h" 25#include "tensorflow/compiler/xla/test.h" 26#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 27#include "tensorflow/compiler/xla/types.h" 28#include "tensorflow/compiler/xla/xla_data.pb.h" 29#include "tensorflow/core/lib/core/status_test_util.h" 30 31namespace xla { 32namespace { 33 34using ::testing::HasSubstr; 35 36using HloVerifierTest = HloTestBase; 37 38TEST_F(HloVerifierTest, NullInstructionParent) { 39 HloComputation::Builder builder(TestName()); 40 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); 41 HloInstruction* param = builder.AddInstruction( 42 HloInstruction::CreateParameter(0, scalar_shape, "param")); 43 HloInstruction* negate = builder.AddInstruction( 44 HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); 45 auto module = CreateNewModule(); 46 module->AddEntryComputation(builder.Build()); 47 48 TF_ASSERT_OK(verifier().Run(module.get()).status()); 49 50 negate->set_parent(nullptr); 51 52 auto status = verifier().Run(module.get()).status(); 53 ASSERT_FALSE(status.ok()); 54 EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer")); 55} 56 57TEST_F(HloVerifierTest, NullComputationParent) { 58 HloComputation::Builder builder(TestName()); 59 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); 60 HloInstruction* param = builder.AddInstruction( 61 HloInstruction::CreateParameter(0, scalar_shape, "param")); 62 builder.AddInstruction( 63 HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); 64 auto module = CreateNewModule(); 65 HloComputation* computation = module->AddEntryComputation(builder.Build()); 66 67 TF_ASSERT_OK(verifier().Run(module.get()).status()); 68 69 computation->set_parent(nullptr); 70 71 auto status = verifier().Run(module.get()).status(); 72 ASSERT_FALSE(status.ok()); 73 EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer")); 74} 75 76TEST_F(HloVerifierTest, DifferentOperandParents) { 77 HloComputation::Builder builder(TestName()); 78 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); 79 HloInstruction* param = builder.AddInstruction( 80 HloInstruction::CreateParameter(0, scalar_shape, "param")); 81 HloInstruction* negate = builder.AddInstruction( 82 HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); 83 auto module = CreateNewModule(); 84 module->AddEntryComputation(builder.Build()); 85 86 HloComputation::Builder emb_builder(TestName()); 87 HloInstruction* emb_param = emb_builder.AddInstruction( 88 HloInstruction::CreateParameter(0, scalar_shape, "param")); 89 module->AddEmbeddedComputation(emb_builder.Build()); 90 91 TF_ASSERT_OK(verifier().Run(module.get()).status()); 92 TF_ASSERT_OK(negate->ReplaceOperandWith(0, emb_param)); 93 94 auto status = verifier().Run(module.get()).status(); 95 ASSERT_FALSE(status.ok()); 96 EXPECT_THAT(status.error_message(), 97 HasSubstr("is in a different computation")); 98} 99 100TEST_F(HloVerifierTest, ResetsShapeVerifierState) { 101 HloComputation::Builder builder(TestName()); 102 Shape s1 = ShapeUtil::MakeShape(F32, {1}); 103 Shape s2 = ShapeUtil::MakeShape(F32, {2}); 104 105 HloInstruction* param = 106 builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "param")); 107 108 // Create an add instruction with the incorrect shape. 109 HloInstruction* add = builder.AddInstruction( 110 HloInstruction::CreateBinary(s2, HloOpcode::kAdd, param, param)); 111 112 // In order to trigger the bug we're checking for, the instruction with the 113 // bad shape can't be the root of the computation. 114 builder.AddInstruction( 115 HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add)); 116 117 auto module = CreateNewModule(); 118 module->AddEntryComputation(builder.Build()); 119 120 // Run the verifier twice. It should fail both times, because it shouldn't 121 // carry state in its DFS visitor between runs. 122 EXPECT_FALSE(verifier().Run(module.get()).status().ok()); 123 EXPECT_FALSE(verifier().Run(module.get()).status().ok()); 124} 125 126} // namespace 127} // namespace xla 128