/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/testlib.h"

#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// HostConst: forced to generate output on the host.
// Only used by testlib; no op is registered for this kernel
// externally (i.e., in array_ops.cc).
REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp);
REGISTER_KERNEL_BUILDER(
    Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp);
#endif  // TENSORFLOW_USE_SYCL

// Register the HostConst op.
// Returns a constant tensor on the host.  Useful for writing C++ tests
// and benchmarks which run on GPU but require arguments pinned to the host.
// Used by test::graph::HostConstant.
// Attr `value` is the tensor to return.
REGISTER_OP("HostConst")
    .Output("output: dtype")
    .Attr("value: tensor")
    .Attr("dtype: type");
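
// Illustrative usage (a minimal sketch, not part of the library itself;
// the variable names below are hypothetical). A test or benchmark that runs
// on GPU but needs an argument pinned to host memory can build one roughly
// as follows:
//
//   Graph* g = new Graph(OpRegistry::Global());
//   Tensor host_arg(DT_INT32, TensorShape({}));
//   host_arg.scalar<int32>()() = 0;
//   Node* pinned = test::graph::HostConstant(g, host_arg);
//
// The resulting "HostConst" node keeps its output in host memory even on GPU,
// because the GPU kernel registration above marks "output" as HostMemory.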

namespace test {
namespace graph {

Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
           const uint64 sender_incarnation, const string& receiver) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
                  .Input(input, 0)
                  .Attr("tensor_name", tensor)
                  .Attr("send_device", sender)
                  .Attr("send_device_incarnation",
                        static_cast<int64>(sender_incarnation))
                  .Attr("recv_device", receiver)
                  .Finalize(g, &ret));
  return ret;
}

Node* Recv(Graph* g, const string& tensor, const string& type,
           const string& sender, const uint64 sender_incarnation,
           const string& receiver) {
  Node* ret;
  DataType dtype;
  CHECK(DataTypeFromString(type, &dtype));
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
                  .Attr("tensor_type", dtype)
                  .Attr("tensor_name", tensor)
                  .Attr("send_device", sender)
                  .Attr("send_device_incarnation",
                        static_cast<int64>(sender_incarnation))
                  .Attr("recv_device", receiver)
                  .Finalize(g, &ret));
  return ret;
}

Node* Constant(Graph* g, const Tensor& tensor) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
                  .Attr("dtype", tensor.dtype())
                  .Attr("value", tensor)
                  .Finalize(g, &ret));
  return ret;
}

Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(name, "Const")
                  .Attr("dtype", tensor.dtype())
                  .Attr("value", tensor)
                  .Finalize(g, &ret));
  return ret;
}

Node* HostConstant(Graph* g, const Tensor& tensor) {
  return HostConstant(g, tensor, g->NewName("n"));
}

Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(name, "HostConst")
                  .Attr("dtype", tensor.dtype())
                  .Attr("value", tensor)
                  .Finalize(g, &ret));
  return ret;
}

Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
                  .Attr("dtype", dtype)
                  .Attr("shape", shape)
                  .Finalize(g, &ret));
  return ret;
}

Node* Var(Graph* g, const DataType dtype, const TensorShape& shape,
          const string& name) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(name, "Variable")
                  .Attr("dtype", dtype)
                  .Attr("shape", shape)
                  .Finalize(g, &ret));
  return ret;
}

Node* Assign(Graph* g, Node* var, Node* val) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
                  .Input(var)
                  .Input(val)
                  .Attr("use_locking", true)
                  .Finalize(g, &ret));
  return ret;
}

Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
             bool keep_dims) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry())
                  .Input(data)
                  .Input(axes)
                  .Attr("keep_dims", keep_dims)
                  .Finalize(g, &ret));
  return ret;
}

Node* QuantizeToUINT8(Graph* g, Node* data) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
                  .Input(data)
                  .Attr("T", DT_QUINT8)
                  .Attr("max_range", 1.0f)
                  .Attr("min_range", -1.0f)
                  .Finalize(g, &ret));
  return ret;
}

Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
             bool transpose_b) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
                  .Input(in0)
                  .Input(in1)
                  .Attr("transpose_a", transpose_a)
                  .Attr("transpose_b", transpose_b)
                  .Finalize(g, &ret));
  return ret;
}

Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul")
                  .Input(in0)
                  .Input(in1)
                  .Attr("adj_x", adj_x)
                  .Attr("adj_y", adj_y)
                  .Finalize(g, &ret));
  return ret;
}

Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
                            DataType dtype) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry())
                  .Input(input)
                  .Attr("dtype", dtype)
                  .Attr("seed", 0)
                  .Finalize(g, &ret));
  return ret;
}

Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
  return RandomNumberGenerator("RandomUniform", g, input, dtype);
}

Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
  return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
}

Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) {
  return RandomNumberGenerator("TruncatedNormal", g, input, dtype);
}

Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma")
                  .Input(shape)
                  .Input(alpha)
                  .Attr("seed", 0)
                  .Finalize(g, &ret));
  return ret;
}

Node* RandomPoisson(Graph* g, Node* shape, Node* lam) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson")
                  .Input(shape)
                  .Input(lam)
                  .Attr("seed", 0)
                  .Finalize(g, &ret));
  return ret;
}

Node* Unary(Graph* g, const string& func, Node* input, int index) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
                  .Input(input, index)
                  .Finalize(g, &ret));
  return ret;
}

Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
                  .Input(in0)
                  .Input(in1)
                  .Finalize(g, &ret));
  return ret;
}

Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
  Node* ret;
  auto b = NodeBuilder(g->NewName("n"), func, g->op_registry());
  for (Node* n : ins) b = b.Input(n);
  TF_CHECK_OK(b.Finalize(g, &ret));
  return ret;
}

Node* Identity(Graph* g, Node* input, int index) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
                  .Input(input, index)
                  .Finalize(g, &ret));
  return ret;
}

Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }

Node* Reverse(Graph* g, Node* tensor, Node* axis) {
  return Binary(g, "ReverseV2", tensor, axis);
}

Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry())
                  .Input(input)
                  .Input(shift)
                  .Input(axis)
                  .Finalize(g, &ret));
  return ret;
}

Node* Error(Graph* g, Node* input, const string& errmsg) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
                  .Input(input)
                  .Attr("message", errmsg)
                  .Finalize(g, &ret));
  return ret;
}

Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
  DCHECK(out_type != invalid_type);
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
                  .Attr("TIn", out_type)
                  .Attr("TOut", invalid_type)
                  .Finalize(g, &ret));
  return ret;
}

Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
                  .Input(input)
                  .Attr("micros", delay_micros.value())
                  .Finalize(g, &ret));
  return ret;
}

Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
                  .ControlInputs(control_inputs)
                  .Finalize(g, &ret));
  return ret;
}

Node* Switch(Graph* g, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
                  .Input(in0)
                  .Input(in1)
                  .Finalize(g, &ret));
  return ret;
}

Node* Enter(Graph* g, Node* input, const string& frame_name) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
                  .Input(input)
                  .Attr("frame_name", frame_name)
                  .Finalize(g, &ret));
  return ret;
}

Node* Exit(Graph* g, Node* input) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
  return ret;
}

Node* Merge(Graph* g, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
                  .Input({in0, in1})
                  .Finalize(g, &ret));
  return ret;
}

Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
  std::vector<NodeBuilder::NodeOut> inputs;
  inputs.reserve(remaining_in.size() + 1);
  inputs.emplace_back(in0);
  for (const string& in_name : remaining_in) {
    inputs.emplace_back(in_name, 0, inputs[0].dt);
  }

  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
  return ret;
}

Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) {
  std::vector<NodeBuilder::NodeOut> nodeouts;
  nodeouts.reserve(tensors.size());
  for (auto const t : tensors) {
    nodeouts.emplace_back(t);
  }
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat")
                  .Input(concat_dim)
                  .Input(nodeouts)
                  .Finalize(g, &ret));
  return ret;
}

Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) {
  std::vector<NodeBuilder::NodeOut> nodeouts;
  nodeouts.reserve(tensors.size());
  for (auto const t : tensors) {
    nodeouts.emplace_back(t);
  }
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2")
                  .Input(nodeouts)
                  .Input(concat_dim)
                  .Finalize(g, &ret));
  return ret;
}

Node* Next(Graph* g, const string& name, Node* input) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
  return ret;
}

Node* LoopCond(Graph* g, Node* input) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
  return ret;
}

Node* Less(Graph* g, Node* in0, Node* in1) {
  return Binary(g, "Less", in0, in1);
}

Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
                  .Input(c)
                  .Input(inx)
                  .Input(iny)
                  .Finalize(g, &ret));
  return ret;
}

Node* Cast(Graph* g, Node* in, DataType dst) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
                  .Input(in)
                  .Attr("DstT", dst)
                  .Finalize(g, &ret));
  return ret;
}

Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherV2")
                  .Input(in0)
                  .Input(in1)
                  .Input(axis)
                  .Finalize(g, &ret));
  return ret;
}

Node* GetSessionTensor(Graph* g, Node* in) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor")
                  .Input(in, 0)
                  .Attr("dtype", DT_FLOAT)
                  .Finalize(g, &ret));
  return ret;
}

Node* Relu(Graph* g, Node* in) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu")
                  .Input(in, 0)
                  .Attr("T", DT_FLOAT)
                  .Finalize(g, &ret));
  return ret;
}

Node* Relu6(Graph* g, Node* in) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6")
                  .Input(in, 0)
                  .Attr("T", DT_FLOAT)
                  .Finalize(g, &ret));
  return ret;
}

Node* BiasAdd(Graph* g, Node* value, Node* bias) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd")
                  .Input(value)
                  .Input(bias)
                  .Attr("T", DT_FLOAT)
                  .Finalize(g, &ret));
  return ret;
}

Node* Conv2D(Graph* g, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D")
                  .Input(in0)
                  .Input(in1)
                  .Attr("T", DT_FLOAT)
                  .Attr("strides", {1, 1, 1, 1})
                  .Attr("padding", "SAME")
                  .Finalize(g, &ret));
  return ret;
}

Node* Diag(Graph* g, Node* in, DataType type) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Diag")
                  .Input(in)
                  .Attr("T", type)
                  .Finalize(g, &ret));
  return ret;
}

Node* DiagPart(Graph* g, Node* in, DataType type) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "DiagPart")
                  .Input(in)
                  .Attr("T", type)
                  .Finalize(g, &ret));
  return ret;
}

void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }

}  // end namespace graph
}  // end namespace test
}  // end namespace tensorflow
516