testlib.cc revision 51e6e84fadd7b37c6e8e5c13cd0de4c7c8e2959f
1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/graph/testlib.h" 17 18#include <vector> 19#include "tensorflow/core/framework/graph.pb.h" 20#include "tensorflow/core/framework/node_def_builder.h" 21#include "tensorflow/core/framework/op.h" 22#include "tensorflow/core/framework/op_kernel.h" 23#include "tensorflow/core/framework/types.h" 24#include "tensorflow/core/framework/types.pb.h" 25#include "tensorflow/core/graph/graph.h" 26#include "tensorflow/core/graph/node_builder.h" 27#include "tensorflow/core/kernels/constant_op.h" 28#include "tensorflow/core/lib/core/status.h" 29#include "tensorflow/core/platform/logging.h" 30 31namespace tensorflow { 32 33// HostConst: forced to generate output on the host. 34// Only used by testlib; no op is registered for this kernel 35// externally (i.e., in array_ops.cc) 36REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp); 37REGISTER_KERNEL_BUILDER( 38 Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp); 39 40// Register the HostConst Op 41// Returns a constant tensor on the host. Useful for writing C++ tests 42// and benchmarks which run on GPU but require arguments pinned to the host. 43// Used by test::graph::HostConstant. 44// value: Attr `value` is the tensor to return. 45REGISTER_OP("HostConst") 46 .Output("output: dtype") 47 .Attr("value: tensor") 48 .Attr("dtype: type"); 49 50namespace test { 51namespace graph { 52 53Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, 54 const uint64 sender_incarnation, const string& receiver) { 55 Node* ret; 56 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send") 57 .Input(input, 0) 58 .Attr("tensor_name", tensor) 59 .Attr("send_device", sender) 60 .Attr("send_device_incarnation", 61 static_cast<int64>(sender_incarnation)) 62 .Attr("recv_device", receiver) 63 .Finalize(g, &ret)); 64 return ret; 65} 66 67Node* Recv(Graph* g, const string& tensor, const string& type, 68 const string& sender, const uint64 sender_incarnation, 69 const string& receiver) { 70 Node* ret; 71 DataType dtype; 72 CHECK(DataTypeFromString(type, &dtype)); 73 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv") 74 .Attr("tensor_type", dtype) 75 .Attr("tensor_name", tensor) 76 .Attr("send_device", sender) 77 .Attr("send_device_incarnation", 78 static_cast<int64>(sender_incarnation)) 79 .Attr("recv_device", receiver) 80 .Finalize(g, &ret)); 81 return ret; 82} 83 84Node* Constant(Graph* g, const Tensor& tensor) { 85 Node* ret; 86 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const") 87 .Attr("dtype", tensor.dtype()) 88 .Attr("value", tensor) 89 .Finalize(g, &ret)); 90 return ret; 91} 92 93Node* Constant(Graph* g, const Tensor& tensor, const string& name) { 94 Node* ret; 95 TF_CHECK_OK(NodeBuilder(name, "Const") 96 .Attr("dtype", tensor.dtype()) 97 .Attr("value", tensor) 98 .Finalize(g, &ret)); 99 return ret; 100} 101 102Node* HostConstant(Graph* g, const Tensor& tensor) { 103 return HostConstant(g, tensor, g->NewName("n")); 104} 105 106Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) { 107 Node* ret; 108 TF_CHECK_OK(NodeBuilder(name, "HostConst") 109 .Attr("dtype", tensor.dtype()) 110 .Attr("value", tensor) 111 .Finalize(g, &ret)); 112 return ret; 113} 114 115Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) { 116 Node* ret; 117 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable") 118 .Attr("dtype", dtype) 119 .Attr("shape", shape) 120 .Finalize(g, &ret)); 121 return ret; 122} 123 124Node* Var(Graph* g, const DataType dtype, const TensorShape& shape, 125 const string& name) { 126 Node* ret; 127 TF_CHECK_OK(NodeBuilder(name, "Variable") 128 .Attr("dtype", dtype) 129 .Attr("shape", shape) 130 .Finalize(g, &ret)); 131 return ret; 132} 133 134Node* Assign(Graph* g, Node* var, Node* val) { 135 Node* ret; 136 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign") 137 .Input(var) 138 .Input(val) 139 .Attr("use_locking", true) 140 .Finalize(g, &ret)); 141 return ret; 142} 143 144Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, 145 bool keep_dims) { 146 Node* ret; 147 TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry()) 148 .Input(data) 149 .Input(axes) 150 .Attr("keep_dims", keep_dims) 151 .Finalize(g, &ret)); 152 return ret; 153} 154 155Node* QuantizeToUINT8(Graph* g, Node* data) { 156 Node* ret; 157 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize") 158 .Input(data) 159 .Attr("T", DT_QUINT8) 160 .Attr("max_range", 1.0f) 161 .Attr("min_range", -1.0f) 162 .Finalize(g, &ret)); 163 return ret; 164} 165 166Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, 167 bool transpose_b) { 168 Node* ret; 169 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul") 170 .Input(in0) 171 .Input(in1) 172 .Attr("transpose_a", transpose_a) 173 .Attr("transpose_b", transpose_b) 174 .Finalize(g, &ret)); 175 return ret; 176} 177 178Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) { 179 Node* ret; 180 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul") 181 .Input(in0) 182 .Input(in1) 183 .Attr("adj_x", adj_x) 184 .Attr("adj_y", adj_y) 185 .Finalize(g, &ret)); 186 return ret; 187} 188 189Node* RandomNumberGenerator(const string& op, Graph* g, Node* input, 190 DataType dtype) { 191 Node* ret; 192 TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry()) 193 .Input(input) 194 .Attr("dtype", dtype) 195 .Attr("seed", 0) 196 .Finalize(g, &ret)); 197 return ret; 198} 199 200Node* RandomUniform(Graph* g, Node* input, DataType dtype) { 201 return RandomNumberGenerator("RandomUniform", g, input, dtype); 202} 203 204Node* RandomGaussian(Graph* g, Node* input, DataType dtype) { 205 return RandomNumberGenerator("RandomStandardNormal", g, input, dtype); 206} 207 208Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) { 209 return RandomNumberGenerator("TruncatedNormal", g, input, dtype); 210} 211 212Node* RandomGamma(Graph* g, Node* shape, Node* alpha) { 213 Node* ret; 214 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma") 215 .Input(shape) 216 .Input(alpha) 217 .Attr("seed", 0) 218 .Finalize(g, &ret)); 219 return ret; 220} 221 222Node* Unary(Graph* g, const string& func, Node* input, int index) { 223 Node* ret; 224 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry()) 225 .Input(input, index) 226 .Finalize(g, &ret)); 227 return ret; 228} 229 230Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) { 231 Node* ret; 232 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry()) 233 .Input(in0) 234 .Input(in1) 235 .Finalize(g, &ret)); 236 return ret; 237} 238 239Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) { 240 Node* ret; 241 auto b = NodeBuilder(g->NewName("n"), func, g->op_registry()); 242 for (Node* n : ins) b = b.Input(n); 243 TF_CHECK_OK(b.Finalize(g, &ret)); 244 return ret; 245} 246 247Node* Identity(Graph* g, Node* input, int index) { 248 Node* ret; 249 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity") 250 .Input(input, index) 251 .Finalize(g, &ret)); 252 return ret; 253} 254 255Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); } 256 257Node* Reverse(Graph* g, Node* tensor, Node* axis) { 258 return Binary(g, "ReverseV2", tensor, axis); 259} 260 261Node* Error(Graph* g, Node* input, const string& errmsg) { 262 Node* ret; 263 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error") 264 .Input(input) 265 .Attr("message", errmsg) 266 .Finalize(g, &ret)); 267 return ret; 268} 269 270Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) { 271 DCHECK(out_type != invalid_type); 272 Node* ret; 273 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType") 274 .Attr("TIn", out_type) 275 .Attr("TOut", invalid_type) 276 .Finalize(g, &ret)); 277 return ret; 278} 279 280Node* Delay(Graph* g, Node* input, Microseconds delay_micros) { 281 Node* ret; 282 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay") 283 .Input(input) 284 .Attr("micros", delay_micros.value()) 285 .Finalize(g, &ret)); 286 return ret; 287} 288 289Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) { 290 Node* ret; 291 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp") 292 .ControlInputs(control_inputs) 293 .Finalize(g, &ret)); 294 return ret; 295} 296 297Node* Switch(Graph* g, Node* in0, Node* in1) { 298 Node* ret; 299 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch") 300 .Input(in0) 301 .Input(in1) 302 .Finalize(g, &ret)); 303 return ret; 304} 305 306Node* Enter(Graph* g, Node* input, const string& frame_name) { 307 Node* ret; 308 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter") 309 .Input(input) 310 .Attr("frame_name", frame_name) 311 .Finalize(g, &ret)); 312 return ret; 313} 314 315Node* Exit(Graph* g, Node* input) { 316 Node* ret; 317 TF_CHECK_OK( 318 NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret)); 319 return ret; 320} 321 322Node* Merge(Graph* g, Node* in0, Node* in1) { 323 Node* ret; 324 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge") 325 .Input({in0, in1}) 326 .Finalize(g, &ret)); 327 return ret; 328} 329 330Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) { 331 std::vector<NodeBuilder::NodeOut> inputs; 332 inputs.reserve(remaining_in.size() + 1); 333 inputs.emplace_back(in0); 334 for (const string& in_name : remaining_in) { 335 inputs.emplace_back(in_name, 0, inputs[0].dt); 336 } 337 338 Node* ret; 339 TF_CHECK_OK( 340 NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret)); 341 return ret; 342} 343 344Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) { 345 std::vector<NodeBuilder::NodeOut> nodeouts; 346 nodeouts.reserve(tensors.size()); 347 for (auto const t : tensors) { 348 nodeouts.emplace_back(t); 349 } 350 Node* ret; 351 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat") 352 .Input(concat_dim) 353 .Input(nodeouts) 354 .Finalize(g, &ret)); 355 return ret; 356} 357 358Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) { 359 std::vector<NodeBuilder::NodeOut> nodeouts; 360 nodeouts.reserve(tensors.size()); 361 for (auto const t : tensors) { 362 nodeouts.emplace_back(t); 363 } 364 Node* ret; 365 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2") 366 .Input(nodeouts) 367 .Input(concat_dim) 368 .Finalize(g, &ret)); 369 return ret; 370} 371 372Node* Next(Graph* g, const string& name, Node* input) { 373 Node* ret; 374 TF_CHECK_OK( 375 NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret)); 376 return ret; 377} 378 379Node* LoopCond(Graph* g, Node* input) { 380 Node* ret; 381 TF_CHECK_OK( 382 NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret)); 383 return ret; 384} 385 386Node* Less(Graph* g, Node* in0, Node* in1) { 387 return Binary(g, "Less", in0, in1); 388} 389 390Node* Select(Graph* g, Node* c, Node* inx, Node* iny) { 391 Node* ret; 392 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select") 393 .Input(c) 394 .Input(inx) 395 .Input(iny) 396 .Finalize(g, &ret)); 397 return ret; 398} 399 400Node* Cast(Graph* g, Node* in, DataType dst) { 401 Node* ret; 402 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast") 403 .Input(in) 404 .Attr("DstT", dst) 405 .Finalize(g, &ret)); 406 return ret; 407} 408 409Node* BroadcastArgs(Graph* g, Node* s0, Node* s1) { 410 Node* ret; 411 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastArgs") 412 .Input(s0) 413 .Input(s1) 414 .Finalize(g, &ret)); 415 return ret; 416} 417 418Node* BroadcastGradientArgs(Graph* g, Node* s0, Node* s1) { 419 Node* ret; 420 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastGradientArgs") 421 .Input(s0) 422 .Input(s1) 423 .Finalize(g, &ret)); 424 return ret; 425} 426 427Node* Gather(Graph* g, Node* in0, Node* in1) { 428 Node* ret; 429 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Gather") 430 .Input(in0) 431 .Input(in1) 432 .Finalize(g, &ret)); 433 return ret; 434} 435 436Node* GetSessionTensor(Graph* g, Node* in) { 437 Node* ret; 438 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor") 439 .Input(in, 0) 440 .Attr("dtype", DT_FLOAT) 441 .Finalize(g, &ret)); 442 return ret; 443} 444 445Node* Relu(Graph* g, Node* in) { 446 Node* ret; 447 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu") 448 .Input(in, 0) 449 .Attr("T", DT_FLOAT) 450 .Finalize(g, &ret)); 451 return ret; 452} 453 454Node* Relu6(Graph* g, Node* in) { 455 Node* ret; 456 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6") 457 .Input(in, 0) 458 .Attr("T", DT_FLOAT) 459 .Finalize(g, &ret)); 460 return ret; 461} 462 463Node* BiasAdd(Graph* g, Node* value, Node* bias) { 464 Node* ret; 465 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd") 466 .Input(value) 467 .Input(bias) 468 .Attr("T", DT_FLOAT) 469 .Finalize(g, &ret)); 470 return ret; 471} 472 473Node* Conv2D(Graph* g, Node* in0, Node* in1) { 474 Node* ret; 475 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D") 476 .Input(in0) 477 .Input(in1) 478 .Attr("T", DT_FLOAT) 479 .Attr("strides", {1, 1, 1, 1}) 480 .Attr("padding", "SAME") 481 .Finalize(g, &ret)); 482 return ret; 483} 484 485void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } 486 487} // end namespace graph 488} // end namespace test 489} // end namespace tensorflow 490