1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/framework/op_segment.h" 17 18#include <vector> 19#include "tensorflow/core/framework/allocator.h" 20#include "tensorflow/core/framework/node_def_builder.h" 21#include "tensorflow/core/framework/op_kernel.h" 22#include "tensorflow/core/kernels/ops_util.h" 23#include "tensorflow/core/lib/core/errors.h" 24#include "tensorflow/core/lib/core/status_test_util.h" 25#include "tensorflow/core/lib/strings/strcat.h" 26#include "tensorflow/core/platform/logging.h" 27#include "tensorflow/core/platform/test.h" 28#include "tensorflow/core/public/version.h" 29 30namespace tensorflow { 31 32class OpSegmentTest : public ::testing::Test { 33 protected: 34 DeviceBase device_; 35 std::vector<NodeDef> int32_nodedefs_; 36 std::vector<NodeDef> float_nodedefs_; 37 38 OpSegmentTest() : device_(Env::Default()) { 39 for (int i = 0; i < 10; ++i) { 40 NodeDef def; 41 TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") 42 .Input("x", 0, DT_INT32) 43 .Input("y", 0, DT_INT32) 44 .Finalize(&def)); 45 int32_nodedefs_.push_back(def); 46 TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") 47 .Input("x", 0, DT_FLOAT) 48 .Input("y", 0, DT_FLOAT) 49 .Finalize(&def)); 50 float_nodedefs_.push_back(def); 51 } 52 } 53 54 void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) { 55 ASSERT_NE(op, nullptr); 56 EXPECT_EQ(expected.DebugString(), op->def().DebugString()); 57 EXPECT_EQ(2, op->num_inputs()); 58 EXPECT_EQ(dt, op->input_type(0)); 59 EXPECT_EQ(dt, op->input_type(1)); 60 EXPECT_EQ(1, op->num_outputs()); 61 EXPECT_EQ(dt, op->output_type(0)); 62 } 63 64 OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) { 65 return [this, ndef](OpKernel** kernel) { 66 Status s; 67 auto created = CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(), 68 *ndef, TF_GRAPH_DEF_VERSION, &s); 69 if (s.ok()) { 70 *kernel = created.release(); 71 } 72 return s; 73 }; 74 } 75}; 76 77TEST_F(OpSegmentTest, Basic) { 78 OpSegment opseg; 79 OpKernel* op; 80 81 opseg.AddHold("A"); 82 opseg.AddHold("B"); 83 for (int i = 0; i < 10; ++i) { 84 // Register in session A. 85 auto* ndef = &float_nodedefs_[i]; 86 TF_EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef))); 87 ValidateOpAndTypes(op, *ndef, DT_FLOAT); 88 89 // Register in session B. 90 ndef = &int32_nodedefs_[i]; 91 TF_EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef))); 92 ValidateOpAndTypes(op, *ndef, DT_INT32); 93 } 94 95 auto reterr = [](OpKernel** kernel) { 96 return errors::Internal("Should not be called"); 97 }; 98 for (int i = 0; i < 10; ++i) { 99 // Lookup op in session A. 100 TF_EXPECT_OK( 101 opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr)); 102 ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT); 103 104 // Lookup op in session B. 105 TF_EXPECT_OK( 106 opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr)); 107 ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32); 108 } 109 110 opseg.RemoveHold("A"); 111 opseg.RemoveHold("B"); 112} 113 114TEST_F(OpSegmentTest, SessionNotFound) { 115 OpSegment opseg; 116 OpKernel* op; 117 NodeDef def = float_nodedefs_[0]; 118 Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); 119 EXPECT_TRUE(errors::IsNotFound(s)) << s; 120} 121 122TEST_F(OpSegmentTest, CreateFailure) { 123 OpSegment opseg; 124 OpKernel* op; 125 NodeDef def = float_nodedefs_[0]; 126 def.set_op("nonexistop"); 127 opseg.AddHold("A"); 128 Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); 129 EXPECT_TRUE(errors::IsNotFound(s)) << s; 130 opseg.RemoveHold("A"); 131} 132 133TEST_F(OpSegmentTest, AddRemoveHolds) { 134 OpSegment opseg; 135 OpKernel* op; 136 const auto& ndef = int32_nodedefs_[0]; 137 138 // No op. 139 opseg.RemoveHold("null"); 140 141 // Thread1 register the op and wants to ensure it alive. 142 opseg.AddHold("foo"); 143 TF_EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef))); 144 145 // Thread2 starts some execution needs "op" to be alive. 146 opseg.AddHold("foo"); 147 148 // Thread1 clears session "foo". E.g., a master sends CleanupGraph 149 // before an execution finishes. 150 opseg.RemoveHold("foo"); 151 152 // Thread2 should still be able to access "op". 153 ValidateOpAndTypes(op, ndef, DT_INT32); 154 155 // Thread2 then remove its hold on "foo". 156 opseg.RemoveHold("foo"); 157} 158 159} // namespace tensorflow 160