1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// DEPRECATED: Use the C++ API defined in tensorflow/cc instead.
17
18#ifndef TENSORFLOW_GRAPH_TESTLIB_H_
19#define TENSORFLOW_GRAPH_TESTLIB_H_
20
21#include <string>
22#include <vector>
23
24#include "tensorflow/core/framework/tensor.h"
25#include "tensorflow/core/framework/tensor_shape.h"
26#include "tensorflow/core/graph/graph.h"
27#include "tensorflow/core/graph/types.h"
28#include "tensorflow/core/platform/types.h"
29
30namespace tensorflow {
31namespace test {
32namespace graph {
33
34// Converts "g" into its corresponding GraphDef "def".
35// DEPRECATED: call g->ToGraphDef(def) instead.
36void ToGraphDef(Graph* g, GraphDef* def);
37
38// A few helpers to construct a graph.
39
40// Adds a node in "g" producing a constant "tensor".
41Node* Constant(Graph* g, const Tensor& tensor);
42Node* Constant(Graph* g, const Tensor& tensor, const string& name);
43
44// Adds a node in "g" producing a constant "tensor" on the host.
45// The given node which, unlike the regular Constant above, always
46// stores its output on the host.  This is necessary for use
47// in GPU tests where the test Op in question runs on the device
48// but requires some arguments to be pinned to the host.
49Node* HostConstant(Graph* g, const Tensor& tensor);
50Node* HostConstant(Graph* g, const Tensor& tensor, const string& name);
51
52// Adds a variable in "g" of the given "shape" and "dtype".
53Node* Var(Graph* g, const DataType dtype, const TensorShape& shape);
54Node* Var(Graph* g, const DataType dtype, const TensorShape& shape,
55          const string& name);
56
57// Adds an assign node in "g" which assigns "val" into "var".
58Node* Assign(Graph* g, Node* var, Node* val);
59
60// Adds a send node "g" sending "input" as a named "tensor" from
61// "sender" to "receiver".
62Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
63           const uint64 sender_incarnation, const string& receiver);
64
65// Adds a recv node in "g" receiving a named "tensor" from "sender"
66// to "receiver".
67Node* Recv(Graph* g, const string& tensor, const string& type,
68           const string& sender, const uint64 sender_incarnation,
69           const string& receiver);
70
71// Adds a reduction "node" in "g" doing sum(data, axes).  "reduce" is
72// a reduction, e.g., Sum, Max, Min, Mean, etc.
73Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
74             bool keep_dims = false);
75
76// Adds a Matmul node in g doing in0.contract(in1).
77Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
78             bool transpose_b);
79
80// Adds a Matmul node in g doing in0.contract(in1).
81Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y);
82
83// Adds a Quantize node into g that quantize floats into QUINT8. The range of
84// the input float tensor is assumed to be [-1, 1].
85Node* QuantizeToUINT8(Graph* g, Node* data);
86
87// Adds a unary function "func" "node" in "g" taking "input".
88Node* Unary(Graph* g, const string& func, Node* input, int index = 0);
89
90// Adds an identity node in "g" taking "input" and producing an
91// identity copy.
92Node* Identity(Graph* g, Node* input, int index = 0);
93
94// Adds a binary function "func" node in "g" taking "in0" and "in1".
95Node* Binary(Graph* g, const string& func, Node* in0, Node* in1);
96
97// Adds a function "func" node in "g" taking inputs "ins".
98Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins);
99
100// Adds a binary add node in "g" doing in0 + in1.
101Node* Add(Graph* g, Node* in0, Node* in1);
102
103// Reverses <axis> dimensions of <tensor>>
104Node* Reverse(Graph* g, Node* tensor, Node* axis);
105
106// Generates random unit uniform distribution of the input shape.
107Node* RandomUniform(Graph* g, Node* input, DataType dtype);
108
109// Generates random unit normal distribution of the input shape.
110Node* RandomGaussian(Graph* g, Node* input, DataType dtype);
111
112// Generates random gamma distribution with the given shape and alpha[s].
113// Output dtype determined by alpha.
114Node* RandomGamma(Graph* g, Node* shape, Node* alpha);
115
116// Generates random poisson distribution with the given shape and lam[s].
117// Output dtype determined by lam.
118Node* RandomPoisson(Graph* g, Node* shape, Node* lam);
119
120// Rolls tensor by an offset of <shift> along the corresponding
121// <axis> dimensions.
122Node* Roll(Graph* g, Node* input, Node* shift, Node* axis);
123
124// Generates random parameters from the truncated standard normal distribution
125// of the nput shape
126Node* TruncatedNormal(Graph* g, Node* input, DataType dtype);
127
128// Adds an error node in "g". The node's computation always
129// generates an error with the given error message "errmsg".
130Node* Error(Graph* g, Node* input, const string& errmsg);
131
132// Adds a node that generates a invalid ref output.
133Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type);
134
135// Adds a node in "g". Its Compute() sleeps a while and outputs the
136// input (i.e., same as identity).
137Node* Delay(Graph* g, Node* input, Microseconds delay_micros);
138
139// Adds a no-op "node" in "g", with control inputs from all nodes in
140// control_inputs vector.
141Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs);
142
143// Adds a Switch node in "g". If "in1" is true, it forwards "in0" to
144// output 1. Otherwise, it forwards "in0" to output 0.
145Node* Switch(Graph* g, Node* in0, Node* in1);
146
147// Adds an Enter node in "g", which enters a new frame.
148Node* Enter(Graph* g, Node* input, const string& frame_name);
149
150// Adds an Exit node in "g", which exits a frame.
151Node* Exit(Graph* g, Node* input);
152
153// Adds a Merge node in "g" with two inputs "in0" and "in1".
154Node* Merge(Graph* g, Node* in0, Node* in1);
155
156// Adds a Merge node in "g". The first input is "in0", the remaining
157// inputs are only given by their names in remaining_in.
158Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in);
159
160// Adds a NextIteration node in "g", which makes its input available
161// to the next iteration.
162Node* Next(Graph* g, const string& name, Node* input);
163
164// Adds a LoopCond node in "g", representing the "pivot" termination
165// condition of a loop.
166Node* LoopCond(Graph* g, Node* input);
167
168// Adds a less node in "g", which returns true iff "in0" < "in1".
169Node* Less(Graph* g, Node* in0, Node* in1);
170
171// Adds a select node in "g", which outputs either "inx" or "iny"
172// depending on the boolean value of "c".
173Node* Select(Graph* g, Node* c, Node* inx, Node* iny);
174
175// Casts "in" into data type "dst".
176Node* Cast(Graph* g, Node* in, DataType dst);
177
178// Perform gather op on params "in0" with indices "in1" and axis "axis".
179Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis);
180
181// Gets a tensor stored in the session state.
182Node* GetSessionTensor(Graph* g, Node* in);
183
184// Adds a Concat node in "g". The first input is "concat_dim", the
185// dimension to concatenate on, and the tensors to concatenate are
186// given in "tensors".
187Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors);
188
189// Adds a ConcatV2 node in "g". The last input is "concat_dim", the
190// dimension to concatenate on, and the tensors to concatenate are
191// given in "tensors".
192Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim);
193
194// Add a Relu node in "g".
195Node* Relu(Graph* g, Node* in);
196
197// Add a Relu6 node in "g".
198Node* Relu6(Graph* g, Node* in);
199
200// Add a BiasAdd node in "g".
201Node* BiasAdd(Graph* g, Node* value, Node* bias);
202
203// Add a Conv2D node in "g".
204Node* Conv2D(Graph* g, Node* in0, Node* in1);
205
206// Add a Diag node in "g".
207Node* Diag(Graph* g, Node* in, DataType type);
208
209// Add a DiagPart node in "g".
210Node* DiagPart(Graph* g, Node* in, DataType type);
211
212}  // end namespace graph
213}  // end namespace test
214}  // end namespace tensorflow
215
216#endif  // TENSORFLOW_GRAPH_TESTLIB_H_
217