testlib.cc revision 2585e0a75e2802ca8b9877fd06544ecca0b95cd9
1/* Copyright 2015 Google Inc. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/graph/testlib.h" 17 18#include <vector> 19#include "tensorflow/core/framework/graph.pb.h" 20#include "tensorflow/core/framework/node_def_builder.h" 21#include "tensorflow/core/framework/op.h" 22#include "tensorflow/core/framework/op_kernel.h" 23#include "tensorflow/core/framework/types.h" 24#include "tensorflow/core/framework/types.pb.h" 25#include "tensorflow/core/graph/graph.h" 26#include "tensorflow/core/graph/node_builder.h" 27#include "tensorflow/core/kernels/constant_op.h" 28#include "tensorflow/core/lib/core/status.h" 29#include "tensorflow/core/platform/logging.h" 30 31namespace tensorflow { 32 33// HostConst: forced to generate output on the host. 34// Only used by testlib; no op is registered for this kernel 35// externally (i.e., in array_ops.cc) 36REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp); 37REGISTER_KERNEL_BUILDER( 38 Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp); 39 40// Register the HostConst Op 41// Returns a constant tensor on the host. Useful for writing C++ tests 42// and benchmarks which run on GPU but require arguments pinned to the host. 43// Used by test::graph::HostConstant. 44// value: Attr `value` is the tensor to return. 45REGISTER_OP("HostConst") 46 .Output("output: dtype") 47 .Attr("value: tensor") 48 .Attr("dtype: type"); 49 50namespace test { 51namespace graph { 52 53Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, 54 const uint64 sender_incarnation, const string& receiver) { 55 Node* ret; 56 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send") 57 .Input(input, 0) 58 .Attr("tensor_name", tensor) 59 .Attr("send_device", sender) 60 .Attr("send_device_incarnation", 61 static_cast<int64>(sender_incarnation)) 62 .Attr("recv_device", receiver) 63 .Finalize(g, &ret)); 64 return ret; 65} 66 67Node* Recv(Graph* g, const string& tensor, const string& type, 68 const string& sender, const uint64 sender_incarnation, 69 const string& receiver) { 70 Node* ret; 71 DataType dtype; 72 CHECK(DataTypeFromString(type, &dtype)); 73 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv") 74 .Attr("tensor_type", dtype) 75 .Attr("tensor_name", tensor) 76 .Attr("send_device", sender) 77 .Attr("send_device_incarnation", 78 static_cast<int64>(sender_incarnation)) 79 .Attr("recv_device", receiver) 80 .Finalize(g, &ret)); 81 return ret; 82} 83 84Node* Constant(Graph* g, const Tensor& tensor) { 85 Node* ret; 86 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const") 87 .Attr("dtype", tensor.dtype()) 88 .Attr("value", tensor) 89 .Finalize(g, &ret)); 90 return ret; 91} 92 93Node* Constant(Graph* g, const Tensor& tensor, const string& name) { 94 Node* ret; 95 TF_CHECK_OK(NodeBuilder(name, "Const") 96 .Attr("dtype", tensor.dtype()) 97 .Attr("value", tensor) 98 .Finalize(g, &ret)); 99 return ret; 100} 101 102Node* HostConstant(Graph* g, const Tensor& tensor) { 103 return HostConstant(g, tensor, g->NewName("n")); 104} 105 106Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) { 107 Node* ret; 108 TF_CHECK_OK(NodeBuilder(name, "HostConst") 109 .Attr("dtype", tensor.dtype()) 110 .Attr("value", tensor) 111 .Finalize(g, &ret)); 112 return ret; 113} 114 115Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) { 116 Node* ret; 117 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable") 118 .Attr("dtype", dtype) 119 .Attr("shape", shape) 120 .Finalize(g, &ret)); 121 return ret; 122} 123 124Node* Assign(Graph* g, Node* var, Node* val) { 125 Node* ret; 126 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign") 127 .Input(var) 128 .Input(val) 129 .Attr("use_locking", true) 130 .Finalize(g, &ret)); 131 return ret; 132} 133 134Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, 135 bool keep_dims) { 136 Node* ret; 137 TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce) 138 .Input(data) 139 .Input(axes) 140 .Attr("keep_dims", keep_dims) 141 .Finalize(g, &ret)); 142 return ret; 143} 144 145Node* QuantizeToUINT8(Graph* g, Node* data) { 146 Node* ret; 147 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize") 148 .Input(data) 149 .Attr("T", DT_QUINT8) 150 .Attr("max_range", 1.0f) 151 .Attr("min_range", -1.0f) 152 .Finalize(g, &ret)); 153 return ret; 154} 155 156Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, 157 bool transpose_b) { 158 Node* ret; 159 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul") 160 .Input(in0) 161 .Input(in1) 162 .Attr("transpose_a", transpose_a) 163 .Attr("transpose_b", transpose_b) 164 .Finalize(g, &ret)); 165 return ret; 166} 167 168Node* RandomNumberGenerator(const string& op, Graph* g, Node* input, 169 DataType dtype) { 170 Node* ret; 171 TF_CHECK_OK(NodeBuilder(g->NewName("n"), op) 172 .Input(input) 173 .Attr("dtype", dtype) 174 .Attr("seed", 0) 175 .Finalize(g, &ret)); 176 return ret; 177} 178 179Node* RandomUniform(Graph* g, Node* input, DataType dtype) { 180 return RandomNumberGenerator("RandomUniform", g, input, dtype); 181} 182 183Node* RandomGaussian(Graph* g, Node* input, DataType dtype) { 184 return RandomNumberGenerator("RandomStandardNormal", g, input, dtype); 185} 186 187Node* RandomParameters(Graph* g, Node* input, DataType dtype) { 188 return RandomNumberGenerator("RandomParameters", g, input, dtype); 189} 190 191Node* Unary(Graph* g, const string& func, Node* input, int index) { 192 Node* ret; 193 TF_CHECK_OK( 194 NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret)); 195 return ret; 196} 197 198Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) { 199 Node* ret; 200 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func) 201 .Input(in0) 202 .Input(in1) 203 .Finalize(g, &ret)); 204 return ret; 205} 206 207Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) { 208 Node* ret; 209 auto b = NodeBuilder(g->NewName("n"), func); 210 for (Node* n : ins) b = b.Input(n); 211 TF_CHECK_OK(b.Finalize(g, &ret)); 212 return ret; 213} 214 215Node* Identity(Graph* g, Node* input, int index) { 216 Node* ret; 217 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity") 218 .Input(input, index) 219 .Finalize(g, &ret)); 220 return ret; 221} 222 223Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); } 224 225Node* Error(Graph* g, Node* input, const string& errmsg) { 226 Node* ret; 227 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error") 228 .Input(input) 229 .Attr("message", errmsg) 230 .Finalize(g, &ret)); 231 return ret; 232} 233 234Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) { 235 DCHECK(out_type != invalid_type); 236 Node* ret; 237 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType") 238 .Attr("TIn", out_type) 239 .Attr("TOut", invalid_type) 240 .Finalize(g, &ret)); 241 return ret; 242} 243 244Node* Delay(Graph* g, Node* input, Microseconds delay_micros) { 245 Node* ret; 246 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay") 247 .Input(input) 248 .Attr("micros", delay_micros.value()) 249 .Finalize(g, &ret)); 250 return ret; 251} 252 253Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) { 254 Node* ret; 255 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp") 256 .ControlInputs(control_inputs) 257 .Finalize(g, &ret)); 258 return ret; 259} 260 261Node* Switch(Graph* g, Node* in0, Node* in1) { 262 Node* ret; 263 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch") 264 .Input(in0) 265 .Input(in1) 266 .Finalize(g, &ret)); 267 return ret; 268} 269 270Node* Enter(Graph* g, Node* input, const string& frame_name) { 271 Node* ret; 272 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter") 273 .Input(input) 274 .Attr("frame_name", frame_name) 275 .Finalize(g, &ret)); 276 return ret; 277} 278 279Node* Exit(Graph* g, Node* input) { 280 Node* ret; 281 TF_CHECK_OK( 282 NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret)); 283 return ret; 284} 285 286Node* Merge(Graph* g, Node* in0, Node* in1) { 287 Node* ret; 288 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge") 289 .Input({in0, in1}) 290 .Finalize(g, &ret)); 291 return ret; 292} 293 294Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) { 295 std::vector<NodeBuilder::NodeOut> inputs; 296 inputs.reserve(remaining_in.size() + 1); 297 inputs.emplace_back(in0); 298 for (const string& in_name : remaining_in) { 299 inputs.emplace_back(in_name, 0, inputs[0].dt); 300 } 301 302 Node* ret; 303 TF_CHECK_OK( 304 NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret)); 305 return ret; 306} 307 308Node* Next(Graph* g, const string& name, Node* input) { 309 Node* ret; 310 TF_CHECK_OK( 311 NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret)); 312 return ret; 313} 314 315Node* LoopCond(Graph* g, Node* input) { 316 Node* ret; 317 TF_CHECK_OK( 318 NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret)); 319 return ret; 320} 321 322Node* Less(Graph* g, Node* in0, Node* in1) { 323 return Binary(g, "Less", in0, in1); 324} 325 326Node* Select(Graph* g, Node* c, Node* inx, Node* iny) { 327 Node* ret; 328 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select") 329 .Input(c) 330 .Input(inx) 331 .Input(iny) 332 .Finalize(g, &ret)); 333 return ret; 334} 335 336Node* Cast(Graph* g, Node* in, DataType dst) { 337 Node* ret; 338 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast") 339 .Input(in) 340 .Attr("DstT", dst) 341 .Finalize(g, &ret)); 342 return ret; 343} 344 345Node* BroadcastGradientArgs(Graph* g, Node* s0, Node* s1) { 346 Node* ret; 347 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastGradientArgs") 348 .Input(s0) 349 .Input(s1) 350 .Finalize(g, &ret)); 351 return ret; 352} 353 354void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } 355 356} // end namespace graph 357} // end namespace test 358} // end namespace tensorflow 359