testlib.cc revision 51e6e84fadd7b37c6e8e5c13cd0de4c7c8e2959f
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/testlib.h"

#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// HostConst: forced to generate output on the host.
// Only used by testlib; the op for this kernel is not registered
// externally (i.e., in array_ops.cc).
REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp);
REGISTER_KERNEL_BUILDER(
    Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp);

// Register the HostConst op.
// Returns a constant tensor on the host. Useful for writing C++ tests and
// benchmarks that run on GPU but require arguments pinned to the host.
// Used by test::graph::HostConstant.
// value: the tensor to return.
REGISTER_OP("HostConst")
    .Output("output: dtype")
    .Attr("value: tensor")
    .Attr("dtype: type");

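// Example (sketch, names illustrative): a test that needs a tensor pinned to
// host memory, even when the graph runs on GPU, can build it through the
// test::graph::HostConstant helper defined below.
//
//   Graph* g = new Graph(OpRegistry::Global());
//   Tensor t(DT_FLOAT, TensorShape({2}));
//   t.flat<float>().setZero();
//   Node* c = test::graph::HostConstant(g, t);  // output stays in host memory
//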
namespace test {
namespace graph {

Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
           const uint64 sender_incarnation, const string& receiver) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
                  .Input(input, 0)
                  .Attr("tensor_name", tensor)
                  .Attr("send_device", sender)
                  .Attr("send_device_incarnation",
                        static_cast<int64>(sender_incarnation))
                  .Attr("recv_device", receiver)
                  .Finalize(g, &ret));
  return ret;
}

Node* Recv(Graph* g, const string& tensor, const string& type,
           const string& sender, const uint64 sender_incarnation,
           const string& receiver) {
  Node* ret;
  DataType dtype;
  CHECK(DataTypeFromString(type, &dtype));
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
                  .Attr("tensor_type", dtype)
                  .Attr("tensor_name", tensor)
                  .Attr("send_device", sender)
                  .Attr("send_device_incarnation",
                        static_cast<int64>(sender_incarnation))
                  .Attr("recv_device", receiver)
                  .Finalize(g, &ret));
  return ret;
}

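// Example (sketch, device names illustrative): the Send/Recv helpers above
// create a matched _Send/_Recv pair; the tensor name, send/recv devices, and
// sender incarnation must agree on both sides for the rendezvous to connect.
//
//   Node* c = test::graph::Constant(g, t);
//   test::graph::Send(g, c, "T", "/cpu:0", 1, "/gpu:0");
//   test::graph::Recv(g, "T", "float", "/cpu:0", 1, "/gpu:0");
//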
Node* Constant(Graph* g, const Tensor& tensor) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
                  .Attr("dtype", tensor.dtype())
                  .Attr("value", tensor)
                  .Finalize(g, &ret));
  return ret;
}

Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(name, "Const")
                  .Attr("dtype", tensor.dtype())
                  .Attr("value", tensor)
                  .Finalize(g, &ret));
  return ret;
}

Node* HostConstant(Graph* g, const Tensor& tensor) {
  return HostConstant(g, tensor, g->NewName("n"));
}

Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(name, "HostConst")
                  .Attr("dtype", tensor.dtype())
                  .Attr("value", tensor)
                  .Finalize(g, &ret));
  return ret;
}

Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
                  .Attr("dtype", dtype)
                  .Attr("shape", shape)
                  .Finalize(g, &ret));
  return ret;
}

Node* Var(Graph* g, const DataType dtype, const TensorShape& shape,
          const string& name) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(name, "Variable")
                  .Attr("dtype", dtype)
                  .Attr("shape", shape)
                  .Finalize(g, &ret));
  return ret;
}

Node* Assign(Graph* g, Node* var, Node* val) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
                  .Input(var)
                  .Input(val)
                  .Attr("use_locking", true)
                  .Finalize(g, &ret));
  return ret;
}

Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
             bool keep_dims) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry())
                  .Input(data)
                  .Input(axes)
                  .Attr("keep_dims", keep_dims)
                  .Finalize(g, &ret));
  return ret;
}

Node* QuantizeToUINT8(Graph* g, Node* data) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
                  .Input(data)
                  .Attr("T", DT_QUINT8)
                  .Attr("max_range", 1.0f)
                  .Attr("min_range", -1.0f)
                  .Finalize(g, &ret));
  return ret;
}

Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
             bool transpose_b) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
                  .Input(in0)
                  .Input(in1)
                  .Attr("transpose_a", transpose_a)
                  .Attr("transpose_b", transpose_b)
                  .Finalize(g, &ret));
  return ret;
}

Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul")
                  .Input(in0)
                  .Input(in1)
                  .Attr("adj_x", adj_x)
                  .Attr("adj_y", adj_y)
                  .Finalize(g, &ret));
  return ret;
}

Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
                            DataType dtype) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry())
                  .Input(input)
                  .Attr("dtype", dtype)
                  .Attr("seed", 0)
                  .Finalize(g, &ret));
  return ret;
}

Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
  return RandomNumberGenerator("RandomUniform", g, input, dtype);
}

Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
  return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
}

Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) {
  return RandomNumberGenerator("TruncatedNormal", g, input, dtype);
}

Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma")
                  .Input(shape)
                  .Input(alpha)
                  .Attr("seed", 0)
                  .Finalize(g, &ret));
  return ret;
}

Node* Unary(Graph* g, const string& func, Node* input, int index) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
                  .Input(input, index)
                  .Finalize(g, &ret));
  return ret;
}

Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
                  .Input(in0)
                  .Input(in1)
                  .Finalize(g, &ret));
  return ret;
}

Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
  Node* ret;
  auto b = NodeBuilder(g->NewName("n"), func, g->op_registry());
  for (Node* n : ins) b = b.Input(n);
  TF_CHECK_OK(b.Finalize(g, &ret));
  return ret;
}

Node* Identity(Graph* g, Node* input, int index) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
                  .Input(input, index)
                  .Finalize(g, &ret));
  return ret;
}

Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }

Node* Reverse(Graph* g, Node* tensor, Node* axis) {
  return Binary(g, "ReverseV2", tensor, axis);
}

Node* Error(Graph* g, Node* input, const string& errmsg) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
                  .Input(input)
                  .Attr("message", errmsg)
                  .Finalize(g, &ret));
  return ret;
}

Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
  DCHECK(out_type != invalid_type);
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
                  .Attr("TIn", out_type)
                  .Attr("TOut", invalid_type)
                  .Finalize(g, &ret));
  return ret;
}

Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
                  .Input(input)
                  .Attr("micros", delay_micros.value())
                  .Finalize(g, &ret));
  return ret;
}

Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
                  .ControlInputs(control_inputs)
                  .Finalize(g, &ret));
  return ret;
}

Node* Switch(Graph* g, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
                  .Input(in0)
                  .Input(in1)
                  .Finalize(g, &ret));
  return ret;
}

Node* Enter(Graph* g, Node* input, const string& frame_name) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
                  .Input(input)
                  .Attr("frame_name", frame_name)
                  .Finalize(g, &ret));
  return ret;
}

Node* Exit(Graph* g, Node* input) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
  return ret;
}

Node* Merge(Graph* g, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
                  .Input({in0, in1})
                  .Finalize(g, &ret));
  return ret;
}

Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
  std::vector<NodeBuilder::NodeOut> inputs;
  inputs.reserve(remaining_in.size() + 1);
  inputs.emplace_back(in0);
  for (const string& in_name : remaining_in) {
    inputs.emplace_back(in_name, 0, inputs[0].dt);
  }

  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
  return ret;
}

Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) {
  std::vector<NodeBuilder::NodeOut> nodeouts;
  nodeouts.reserve(tensors.size());
  for (auto const t : tensors) {
    nodeouts.emplace_back(t);
  }
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat")
                  .Input(concat_dim)
                  .Input(nodeouts)
                  .Finalize(g, &ret));
  return ret;
}

Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) {
  std::vector<NodeBuilder::NodeOut> nodeouts;
  nodeouts.reserve(tensors.size());
  for (auto const t : tensors) {
    nodeouts.emplace_back(t);
  }
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2")
                  .Input(nodeouts)
                  .Input(concat_dim)
                  .Finalize(g, &ret));
  return ret;
}

Node* Next(Graph* g, const string& name, Node* input) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
  return ret;
}

Node* LoopCond(Graph* g, Node* input) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
  return ret;
}

Node* Less(Graph* g, Node* in0, Node* in1) {
  return Binary(g, "Less", in0, in1);
}

Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
                  .Input(c)
                  .Input(inx)
                  .Input(iny)
                  .Finalize(g, &ret));
  return ret;
}

Node* Cast(Graph* g, Node* in, DataType dst) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
                  .Input(in)
                  .Attr("DstT", dst)
                  .Finalize(g, &ret));
  return ret;
}

Node* BroadcastArgs(Graph* g, Node* s0, Node* s1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastArgs")
                  .Input(s0)
                  .Input(s1)
                  .Finalize(g, &ret));
  return ret;
}

Node* BroadcastGradientArgs(Graph* g, Node* s0, Node* s1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastGradientArgs")
                  .Input(s0)
                  .Input(s1)
                  .Finalize(g, &ret));
  return ret;
}

Node* Gather(Graph* g, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Gather")
                  .Input(in0)
                  .Input(in1)
                  .Finalize(g, &ret));
  return ret;
}

Node* GetSessionTensor(Graph* g, Node* in) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor")
                  .Input(in, 0)
                  .Attr("dtype", DT_FLOAT)
                  .Finalize(g, &ret));
  return ret;
}

Node* Relu(Graph* g, Node* in) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu")
                  .Input(in, 0)
                  .Attr("T", DT_FLOAT)
                  .Finalize(g, &ret));
  return ret;
}

Node* Relu6(Graph* g, Node* in) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6")
                  .Input(in, 0)
                  .Attr("T", DT_FLOAT)
                  .Finalize(g, &ret));
  return ret;
}

Node* BiasAdd(Graph* g, Node* value, Node* bias) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd")
                  .Input(value)
                  .Input(bias)
                  .Attr("T", DT_FLOAT)
                  .Finalize(g, &ret));
  return ret;
}

Node* Conv2D(Graph* g, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D")
                  .Input(in0)
                  .Input(in1)
                  .Attr("T", DT_FLOAT)
                  .Attr("strides", {1, 1, 1, 1})
                  .Attr("padding", "SAME")
                  .Finalize(g, &ret));
  return ret;
}

void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }

}  // end namespace graph
}  // end namespace test
}  // end namespace tensorflow