testlib.cc revision 2585e0a75e2802ca8b9877fd06544ecca0b95cd9
1/* Copyright 2015 Google Inc. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/graph/testlib.h"
17
18#include <vector>
19#include "tensorflow/core/framework/graph.pb.h"
20#include "tensorflow/core/framework/node_def_builder.h"
21#include "tensorflow/core/framework/op.h"
22#include "tensorflow/core/framework/op_kernel.h"
23#include "tensorflow/core/framework/types.h"
24#include "tensorflow/core/framework/types.pb.h"
25#include "tensorflow/core/graph/graph.h"
26#include "tensorflow/core/graph/node_builder.h"
27#include "tensorflow/core/kernels/constant_op.h"
28#include "tensorflow/core/lib/core/status.h"
29#include "tensorflow/core/platform/logging.h"
30
31namespace tensorflow {
32
33// HostConst: forced to generate output on the host.
34// Only used by testlib; no op is registered for this kernel
35// externally (i.e., in array_ops.cc)
36REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp);
37REGISTER_KERNEL_BUILDER(
38    Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp);
39
40// Register the HostConst Op
41// Returns a constant tensor on the host.  Useful for writing C++ tests
42// and benchmarks which run on GPU but require arguments pinned to the host.
43// Used by test::graph::HostConstant.
44// value: Attr `value` is the tensor to return.
45REGISTER_OP("HostConst")
46    .Output("output: dtype")
47    .Attr("value: tensor")
48    .Attr("dtype: type");
49
50namespace test {
51namespace graph {
52
53Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
54           const uint64 sender_incarnation, const string& receiver) {
55  Node* ret;
56  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
57                  .Input(input, 0)
58                  .Attr("tensor_name", tensor)
59                  .Attr("send_device", sender)
60                  .Attr("send_device_incarnation",
61                        static_cast<int64>(sender_incarnation))
62                  .Attr("recv_device", receiver)
63                  .Finalize(g, &ret));
64  return ret;
65}
66
67Node* Recv(Graph* g, const string& tensor, const string& type,
68           const string& sender, const uint64 sender_incarnation,
69           const string& receiver) {
70  Node* ret;
71  DataType dtype;
72  CHECK(DataTypeFromString(type, &dtype));
73  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
74                  .Attr("tensor_type", dtype)
75                  .Attr("tensor_name", tensor)
76                  .Attr("send_device", sender)
77                  .Attr("send_device_incarnation",
78                        static_cast<int64>(sender_incarnation))
79                  .Attr("recv_device", receiver)
80                  .Finalize(g, &ret));
81  return ret;
82}
83
84Node* Constant(Graph* g, const Tensor& tensor) {
85  Node* ret;
86  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
87                  .Attr("dtype", tensor.dtype())
88                  .Attr("value", tensor)
89                  .Finalize(g, &ret));
90  return ret;
91}
92
93Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
94  Node* ret;
95  TF_CHECK_OK(NodeBuilder(name, "Const")
96                  .Attr("dtype", tensor.dtype())
97                  .Attr("value", tensor)
98                  .Finalize(g, &ret));
99  return ret;
100}
101
102Node* HostConstant(Graph* g, const Tensor& tensor) {
103  return HostConstant(g, tensor, g->NewName("n"));
104}
105
106Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) {
107  Node* ret;
108  TF_CHECK_OK(NodeBuilder(name, "HostConst")
109                  .Attr("dtype", tensor.dtype())
110                  .Attr("value", tensor)
111                  .Finalize(g, &ret));
112  return ret;
113}
114
115Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
116  Node* ret;
117  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
118                  .Attr("dtype", dtype)
119                  .Attr("shape", shape)
120                  .Finalize(g, &ret));
121  return ret;
122}
123
124Node* Assign(Graph* g, Node* var, Node* val) {
125  Node* ret;
126  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
127                  .Input(var)
128                  .Input(val)
129                  .Attr("use_locking", true)
130                  .Finalize(g, &ret));
131  return ret;
132}
133
134Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
135             bool keep_dims) {
136  Node* ret;
137  TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce)
138                  .Input(data)
139                  .Input(axes)
140                  .Attr("keep_dims", keep_dims)
141                  .Finalize(g, &ret));
142  return ret;
143}
144
145Node* QuantizeToUINT8(Graph* g, Node* data) {
146  Node* ret;
147  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
148                  .Input(data)
149                  .Attr("T", DT_QUINT8)
150                  .Attr("max_range", 1.0f)
151                  .Attr("min_range", -1.0f)
152                  .Finalize(g, &ret));
153  return ret;
154}
155
156Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
157             bool transpose_b) {
158  Node* ret;
159  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
160                  .Input(in0)
161                  .Input(in1)
162                  .Attr("transpose_a", transpose_a)
163                  .Attr("transpose_b", transpose_b)
164                  .Finalize(g, &ret));
165  return ret;
166}
167
168Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
169                            DataType dtype) {
170  Node* ret;
171  TF_CHECK_OK(NodeBuilder(g->NewName("n"), op)
172                  .Input(input)
173                  .Attr("dtype", dtype)
174                  .Attr("seed", 0)
175                  .Finalize(g, &ret));
176  return ret;
177}
178
179Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
180  return RandomNumberGenerator("RandomUniform", g, input, dtype);
181}
182
183Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
184  return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
185}
186
187Node* RandomParameters(Graph* g, Node* input, DataType dtype) {
188  return RandomNumberGenerator("RandomParameters", g, input, dtype);
189}
190
191Node* Unary(Graph* g, const string& func, Node* input, int index) {
192  Node* ret;
193  TF_CHECK_OK(
194      NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret));
195  return ret;
196}
197
198Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
199  Node* ret;
200  TF_CHECK_OK(NodeBuilder(g->NewName("n"), func)
201                  .Input(in0)
202                  .Input(in1)
203                  .Finalize(g, &ret));
204  return ret;
205}
206
207Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
208  Node* ret;
209  auto b = NodeBuilder(g->NewName("n"), func);
210  for (Node* n : ins) b = b.Input(n);
211  TF_CHECK_OK(b.Finalize(g, &ret));
212  return ret;
213}
214
215Node* Identity(Graph* g, Node* input, int index) {
216  Node* ret;
217  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
218                  .Input(input, index)
219                  .Finalize(g, &ret));
220  return ret;
221}
222
223Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
224
225Node* Error(Graph* g, Node* input, const string& errmsg) {
226  Node* ret;
227  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
228                  .Input(input)
229                  .Attr("message", errmsg)
230                  .Finalize(g, &ret));
231  return ret;
232}
233
234Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
235  DCHECK(out_type != invalid_type);
236  Node* ret;
237  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
238                  .Attr("TIn", out_type)
239                  .Attr("TOut", invalid_type)
240                  .Finalize(g, &ret));
241  return ret;
242}
243
244Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
245  Node* ret;
246  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
247                  .Input(input)
248                  .Attr("micros", delay_micros.value())
249                  .Finalize(g, &ret));
250  return ret;
251}
252
253Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
254  Node* ret;
255  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
256                  .ControlInputs(control_inputs)
257                  .Finalize(g, &ret));
258  return ret;
259}
260
261Node* Switch(Graph* g, Node* in0, Node* in1) {
262  Node* ret;
263  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
264                  .Input(in0)
265                  .Input(in1)
266                  .Finalize(g, &ret));
267  return ret;
268}
269
270Node* Enter(Graph* g, Node* input, const string& frame_name) {
271  Node* ret;
272  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
273                  .Input(input)
274                  .Attr("frame_name", frame_name)
275                  .Finalize(g, &ret));
276  return ret;
277}
278
279Node* Exit(Graph* g, Node* input) {
280  Node* ret;
281  TF_CHECK_OK(
282      NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
283  return ret;
284}
285
286Node* Merge(Graph* g, Node* in0, Node* in1) {
287  Node* ret;
288  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
289                  .Input({in0, in1})
290                  .Finalize(g, &ret));
291  return ret;
292}
293
294Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
295  std::vector<NodeBuilder::NodeOut> inputs;
296  inputs.reserve(remaining_in.size() + 1);
297  inputs.emplace_back(in0);
298  for (const string& in_name : remaining_in) {
299    inputs.emplace_back(in_name, 0, inputs[0].dt);
300  }
301
302  Node* ret;
303  TF_CHECK_OK(
304      NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
305  return ret;
306}
307
308Node* Next(Graph* g, const string& name, Node* input) {
309  Node* ret;
310  TF_CHECK_OK(
311      NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
312  return ret;
313}
314
315Node* LoopCond(Graph* g, Node* input) {
316  Node* ret;
317  TF_CHECK_OK(
318      NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
319  return ret;
320}
321
322Node* Less(Graph* g, Node* in0, Node* in1) {
323  return Binary(g, "Less", in0, in1);
324}
325
326Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
327  Node* ret;
328  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
329                  .Input(c)
330                  .Input(inx)
331                  .Input(iny)
332                  .Finalize(g, &ret));
333  return ret;
334}
335
336Node* Cast(Graph* g, Node* in, DataType dst) {
337  Node* ret;
338  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
339                  .Input(in)
340                  .Attr("DstT", dst)
341                  .Finalize(g, &ret));
342  return ret;
343}
344
345Node* BroadcastGradientArgs(Graph* g, Node* s0, Node* s1) {
346  Node* ret;
347  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastGradientArgs")
348                  .Input(s0)
349                  .Input(s1)
350                  .Finalize(g, &ret));
351  return ret;
352}
353
354void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
355
356}  // end namespace graph
357}  // end namespace test
358}  // end namespace tensorflow
359