1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/op_segment.h"
17
18#include <vector>
19#include "tensorflow/core/framework/allocator.h"
20#include "tensorflow/core/framework/node_def_builder.h"
21#include "tensorflow/core/framework/op_kernel.h"
22#include "tensorflow/core/kernels/ops_util.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/lib/core/status_test_util.h"
25#include "tensorflow/core/lib/strings/strcat.h"
26#include "tensorflow/core/platform/logging.h"
27#include "tensorflow/core/platform/test.h"
28#include "tensorflow/core/public/version.h"
29
30namespace tensorflow {
31
32class OpSegmentTest : public ::testing::Test {
33 protected:
34  DeviceBase device_;
35  std::vector<NodeDef> int32_nodedefs_;
36  std::vector<NodeDef> float_nodedefs_;
37
38  OpSegmentTest() : device_(Env::Default()) {
39    for (int i = 0; i < 10; ++i) {
40      NodeDef def;
41      TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
42                      .Input("x", 0, DT_INT32)
43                      .Input("y", 0, DT_INT32)
44                      .Finalize(&def));
45      int32_nodedefs_.push_back(def);
46      TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
47                      .Input("x", 0, DT_FLOAT)
48                      .Input("y", 0, DT_FLOAT)
49                      .Finalize(&def));
50      float_nodedefs_.push_back(def);
51    }
52  }
53
54  void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) {
55    ASSERT_NE(op, nullptr);
56    EXPECT_EQ(expected.DebugString(), op->def().DebugString());
57    EXPECT_EQ(2, op->num_inputs());
58    EXPECT_EQ(dt, op->input_type(0));
59    EXPECT_EQ(dt, op->input_type(1));
60    EXPECT_EQ(1, op->num_outputs());
61    EXPECT_EQ(dt, op->output_type(0));
62  }
63
64  OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) {
65    return [this, ndef](OpKernel** kernel) {
66      Status s;
67      auto created = CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(),
68                                    *ndef, TF_GRAPH_DEF_VERSION, &s);
69      if (s.ok()) {
70        *kernel = created.release();
71      }
72      return s;
73    };
74  }
75};
76
77TEST_F(OpSegmentTest, Basic) {
78  OpSegment opseg;
79  OpKernel* op;
80
81  opseg.AddHold("A");
82  opseg.AddHold("B");
83  for (int i = 0; i < 10; ++i) {
84    // Register in session A.
85    auto* ndef = &float_nodedefs_[i];
86    TF_EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef)));
87    ValidateOpAndTypes(op, *ndef, DT_FLOAT);
88
89    // Register in session B.
90    ndef = &int32_nodedefs_[i];
91    TF_EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef)));
92    ValidateOpAndTypes(op, *ndef, DT_INT32);
93  }
94
95  auto reterr = [](OpKernel** kernel) {
96    return errors::Internal("Should not be called");
97  };
98  for (int i = 0; i < 10; ++i) {
99    // Lookup op in session A.
100    TF_EXPECT_OK(
101        opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr));
102    ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT);
103
104    // Lookup op in session B.
105    TF_EXPECT_OK(
106        opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr));
107    ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32);
108  }
109
110  opseg.RemoveHold("A");
111  opseg.RemoveHold("B");
112}
113
114TEST_F(OpSegmentTest, SessionNotFound) {
115  OpSegment opseg;
116  OpKernel* op;
117  NodeDef def = float_nodedefs_[0];
118  Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
119  EXPECT_TRUE(errors::IsNotFound(s)) << s;
120}
121
122TEST_F(OpSegmentTest, CreateFailure) {
123  OpSegment opseg;
124  OpKernel* op;
125  NodeDef def = float_nodedefs_[0];
126  def.set_op("nonexistop");
127  opseg.AddHold("A");
128  Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
129  EXPECT_TRUE(errors::IsNotFound(s)) << s;
130  opseg.RemoveHold("A");
131}
132
133TEST_F(OpSegmentTest, AddRemoveHolds) {
134  OpSegment opseg;
135  OpKernel* op;
136  const auto& ndef = int32_nodedefs_[0];
137
138  // No op.
139  opseg.RemoveHold("null");
140
141  // Thread1 register the op and wants to ensure it alive.
142  opseg.AddHold("foo");
143  TF_EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef)));
144
145  // Thread2 starts some execution needs "op" to be alive.
146  opseg.AddHold("foo");
147
148  // Thread1 clears session "foo".  E.g., a master sends CleanupGraph
149  // before an execution finishes.
150  opseg.RemoveHold("foo");
151
152  // Thread2 should still be able to access "op".
153  ValidateOpAndTypes(op, ndef, DT_INT32);
154
155  // Thread2 then remove its hold on "foo".
156  opseg.RemoveHold("foo");
157}
158
159}  // namespace tensorflow
160