1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/graph/testlib.h" 17 18#include <vector> 19#include "tensorflow/core/framework/graph.pb.h" 20#include "tensorflow/core/framework/node_def_builder.h" 21#include "tensorflow/core/framework/node_def_util.h" 22#include "tensorflow/core/framework/op.h" 23#include "tensorflow/core/framework/op_kernel.h" 24#include "tensorflow/core/framework/types.h" 25#include "tensorflow/core/framework/types.pb.h" 26#include "tensorflow/core/graph/graph.h" 27#include "tensorflow/core/graph/node_builder.h" 28#include "tensorflow/core/kernels/constant_op.h" 29#include "tensorflow/core/lib/core/status.h" 30#include "tensorflow/core/platform/logging.h" 31 32namespace tensorflow { 33 34// HostConst: forced to generate output on the host. 35// Only used by testlib; no op is registered for this kernel 36// externally (i.e., in array_ops.cc) 37REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp); 38REGISTER_KERNEL_BUILDER( 39 Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp); 40#ifdef TENSORFLOW_USE_SYCL 41REGISTER_KERNEL_BUILDER( 42 Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp); 43#endif // TENSORFLOW_USE_SYCL 44 45// Register the HostConst Op 46// Returns a constant tensor on the host. Useful for writing C++ tests 47// and benchmarks which run on GPU but require arguments pinned to the host. 48// Used by test::graph::HostConstant. 49// value: Attr `value` is the tensor to return. 50REGISTER_OP("HostConst") 51 .Output("output: dtype") 52 .Attr("value: tensor") 53 .Attr("dtype: type"); 54 55namespace test { 56namespace graph { 57 58Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, 59 const uint64 sender_incarnation, const string& receiver) { 60 Node* ret; 61 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send") 62 .Input(input, 0) 63 .Attr("tensor_name", tensor) 64 .Attr("send_device", sender) 65 .Attr("send_device_incarnation", 66 static_cast<int64>(sender_incarnation)) 67 .Attr("recv_device", receiver) 68 .Finalize(g, &ret)); 69 return ret; 70} 71 72Node* Recv(Graph* g, const string& tensor, const string& type, 73 const string& sender, const uint64 sender_incarnation, 74 const string& receiver) { 75 Node* ret; 76 DataType dtype; 77 CHECK(DataTypeFromString(type, &dtype)); 78 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv") 79 .Attr("tensor_type", dtype) 80 .Attr("tensor_name", tensor) 81 .Attr("send_device", sender) 82 .Attr("send_device_incarnation", 83 static_cast<int64>(sender_incarnation)) 84 .Attr("recv_device", receiver) 85 .Finalize(g, &ret)); 86 return ret; 87} 88 89Node* Constant(Graph* g, const Tensor& tensor) { 90 Node* ret; 91 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const") 92 .Attr("dtype", tensor.dtype()) 93 .Attr("value", tensor) 94 .Finalize(g, &ret)); 95 return ret; 96} 97 98Node* Constant(Graph* g, const Tensor& tensor, const string& name) { 99 Node* ret; 100 TF_CHECK_OK(NodeBuilder(name, "Const") 101 .Attr("dtype", tensor.dtype()) 102 .Attr("value", tensor) 103 .Finalize(g, &ret)); 104 return ret; 105} 106 107Node* HostConstant(Graph* g, const Tensor& tensor) { 108 return HostConstant(g, tensor, g->NewName("n")); 109} 110 111Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) { 112 Node* ret; 113 TF_CHECK_OK(NodeBuilder(name, "HostConst") 114 .Attr("dtype", tensor.dtype()) 115 .Attr("value", tensor) 116 .Finalize(g, &ret)); 117 return ret; 118} 119 120Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) { 121 Node* ret; 122 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable") 123 .Attr("dtype", dtype) 124 .Attr("shape", shape) 125 .Finalize(g, &ret)); 126 return ret; 127} 128 129Node* Var(Graph* g, const DataType dtype, const TensorShape& shape, 130 const string& name) { 131 Node* ret; 132 TF_CHECK_OK(NodeBuilder(name, "Variable") 133 .Attr("dtype", dtype) 134 .Attr("shape", shape) 135 .Finalize(g, &ret)); 136 return ret; 137} 138 139Node* Assign(Graph* g, Node* var, Node* val) { 140 Node* ret; 141 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign") 142 .Input(var) 143 .Input(val) 144 .Attr("use_locking", true) 145 .Finalize(g, &ret)); 146 return ret; 147} 148 149Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, 150 bool keep_dims) { 151 Node* ret; 152 TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry()) 153 .Input(data) 154 .Input(axes) 155 .Attr("keep_dims", keep_dims) 156 .Finalize(g, &ret)); 157 return ret; 158} 159 160Node* QuantizeToUINT8(Graph* g, Node* data) { 161 Node* ret; 162 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize") 163 .Input(data) 164 .Attr("T", DT_QUINT8) 165 .Attr("max_range", 1.0f) 166 .Attr("min_range", -1.0f) 167 .Finalize(g, &ret)); 168 return ret; 169} 170 171Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, 172 bool transpose_b) { 173 Node* ret; 174 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul") 175 .Input(in0) 176 .Input(in1) 177 .Attr("transpose_a", transpose_a) 178 .Attr("transpose_b", transpose_b) 179 .Finalize(g, &ret)); 180 return ret; 181} 182 183Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) { 184 Node* ret; 185 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul") 186 .Input(in0) 187 .Input(in1) 188 .Attr("adj_x", adj_x) 189 .Attr("adj_y", adj_y) 190 .Finalize(g, &ret)); 191 return ret; 192} 193 194Node* RandomNumberGenerator(const string& op, Graph* g, Node* input, 195 DataType dtype) { 196 Node* ret; 197 TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry()) 198 .Input(input) 199 .Attr("dtype", dtype) 200 .Attr("seed", 0) 201 .Finalize(g, &ret)); 202 return ret; 203} 204 205Node* RandomUniform(Graph* g, Node* input, DataType dtype) { 206 return RandomNumberGenerator("RandomUniform", g, input, dtype); 207} 208 209Node* RandomGaussian(Graph* g, Node* input, DataType dtype) { 210 return RandomNumberGenerator("RandomStandardNormal", g, input, dtype); 211} 212 213Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) { 214 return RandomNumberGenerator("TruncatedNormal", g, input, dtype); 215} 216 217Node* RandomGamma(Graph* g, Node* shape, Node* alpha) { 218 Node* ret; 219 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma") 220 .Input(shape) 221 .Input(alpha) 222 .Attr("seed", 0) 223 .Finalize(g, &ret)); 224 return ret; 225} 226 227Node* RandomPoisson(Graph* g, Node* shape, Node* lam) { 228 Node* ret; 229 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson") 230 .Input(shape) 231 .Input(lam) 232 .Attr("seed", 0) 233 .Finalize(g, &ret)); 234 return ret; 235} 236 237Node* Unary(Graph* g, const string& func, Node* input, int index) { 238 Node* ret; 239 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry()) 240 .Input(input, index) 241 .Finalize(g, &ret)); 242 return ret; 243} 244 245Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) { 246 Node* ret; 247 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry()) 248 .Input(in0) 249 .Input(in1) 250 .Finalize(g, &ret)); 251 return ret; 252} 253 254Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) { 255 Node* ret; 256 auto b = NodeBuilder(g->NewName("n"), func, g->op_registry()); 257 for (Node* n : ins) b = b.Input(n); 258 TF_CHECK_OK(b.Finalize(g, &ret)); 259 return ret; 260} 261 262Node* Identity(Graph* g, Node* input, int index) { 263 Node* ret; 264 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity") 265 .Input(input, index) 266 .Finalize(g, &ret)); 267 return ret; 268} 269 270Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); } 271 272Node* Reverse(Graph* g, Node* tensor, Node* axis) { 273 return Binary(g, "ReverseV2", tensor, axis); 274} 275 276Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) { 277 Node* ret; 278 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry()) 279 .Input(input) 280 .Input(shift) 281 .Input(axis) 282 .Finalize(g, &ret)); 283 return ret; 284} 285 286Node* Error(Graph* g, Node* input, const string& errmsg) { 287 Node* ret; 288 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error") 289 .Input(input) 290 .Attr("message", errmsg) 291 .Finalize(g, &ret)); 292 return ret; 293} 294 295Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) { 296 DCHECK(out_type != invalid_type); 297 Node* ret; 298 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType") 299 .Attr("TIn", out_type) 300 .Attr("TOut", invalid_type) 301 .Finalize(g, &ret)); 302 return ret; 303} 304 305Node* Delay(Graph* g, Node* input, Microseconds delay_micros) { 306 Node* ret; 307 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay") 308 .Input(input) 309 .Attr("micros", delay_micros.value()) 310 .Finalize(g, &ret)); 311 return ret; 312} 313 314Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) { 315 Node* ret; 316 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp") 317 .ControlInputs(control_inputs) 318 .Finalize(g, &ret)); 319 return ret; 320} 321 322Node* Switch(Graph* g, Node* in0, Node* in1) { 323 Node* ret; 324 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch") 325 .Input(in0) 326 .Input(in1) 327 .Finalize(g, &ret)); 328 return ret; 329} 330 331Node* Enter(Graph* g, Node* input, const string& frame_name) { 332 Node* ret; 333 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter") 334 .Input(input) 335 .Attr("frame_name", frame_name) 336 .Finalize(g, &ret)); 337 return ret; 338} 339 340Node* Exit(Graph* g, Node* input) { 341 Node* ret; 342 TF_CHECK_OK( 343 NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret)); 344 return ret; 345} 346 347Node* Merge(Graph* g, Node* in0, Node* in1) { 348 Node* ret; 349 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge") 350 .Input({in0, in1}) 351 .Finalize(g, &ret)); 352 return ret; 353} 354 355Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) { 356 std::vector<NodeBuilder::NodeOut> inputs; 357 inputs.reserve(remaining_in.size() + 1); 358 inputs.emplace_back(in0); 359 for (const string& in_name : remaining_in) { 360 inputs.emplace_back(in_name, 0, inputs[0].dt); 361 } 362 363 Node* ret; 364 TF_CHECK_OK( 365 NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret)); 366 return ret; 367} 368 369Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) { 370 std::vector<NodeBuilder::NodeOut> nodeouts; 371 nodeouts.reserve(tensors.size()); 372 for (auto const t : tensors) { 373 nodeouts.emplace_back(t); 374 } 375 Node* ret; 376 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat") 377 .Input(concat_dim) 378 .Input(nodeouts) 379 .Finalize(g, &ret)); 380 return ret; 381} 382 383Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) { 384 std::vector<NodeBuilder::NodeOut> nodeouts; 385 nodeouts.reserve(tensors.size()); 386 for (auto const t : tensors) { 387 nodeouts.emplace_back(t); 388 } 389 Node* ret; 390 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2") 391 .Input(nodeouts) 392 .Input(concat_dim) 393 .Finalize(g, &ret)); 394 return ret; 395} 396 397Node* Next(Graph* g, const string& name, Node* input) { 398 Node* ret; 399 TF_CHECK_OK( 400 NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret)); 401 return ret; 402} 403 404Node* LoopCond(Graph* g, Node* input) { 405 Node* ret; 406 TF_CHECK_OK( 407 NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret)); 408 return ret; 409} 410 411Node* Less(Graph* g, Node* in0, Node* in1) { 412 return Binary(g, "Less", in0, in1); 413} 414 415Node* Select(Graph* g, Node* c, Node* inx, Node* iny) { 416 Node* ret; 417 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select") 418 .Input(c) 419 .Input(inx) 420 .Input(iny) 421 .Finalize(g, &ret)); 422 return ret; 423} 424 425Node* Cast(Graph* g, Node* in, DataType dst) { 426 Node* ret; 427 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast") 428 .Input(in) 429 .Attr("DstT", dst) 430 .Finalize(g, &ret)); 431 return ret; 432} 433 434Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis) { 435 Node* ret; 436 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherV2") 437 .Input(in0) 438 .Input(in1) 439 .Input(axis) 440 .Finalize(g, &ret)); 441 return ret; 442} 443 444Node* GetSessionTensor(Graph* g, Node* in) { 445 Node* ret; 446 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor") 447 .Input(in, 0) 448 .Attr("dtype", DT_FLOAT) 449 .Finalize(g, &ret)); 450 return ret; 451} 452 453Node* Relu(Graph* g, Node* in) { 454 Node* ret; 455 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu") 456 .Input(in, 0) 457 .Attr("T", DT_FLOAT) 458 .Finalize(g, &ret)); 459 return ret; 460} 461 462Node* Relu6(Graph* g, Node* in) { 463 Node* ret; 464 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6") 465 .Input(in, 0) 466 .Attr("T", DT_FLOAT) 467 .Finalize(g, &ret)); 468 return ret; 469} 470 471Node* BiasAdd(Graph* g, Node* value, Node* bias) { 472 Node* ret; 473 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd") 474 .Input(value) 475 .Input(bias) 476 .Attr("T", DT_FLOAT) 477 .Finalize(g, &ret)); 478 return ret; 479} 480 481Node* Conv2D(Graph* g, Node* in0, Node* in1) { 482 Node* ret; 483 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D") 484 .Input(in0) 485 .Input(in1) 486 .Attr("T", DT_FLOAT) 487 .Attr("strides", {1, 1, 1, 1}) 488 .Attr("padding", "SAME") 489 .Finalize(g, &ret)); 490 return ret; 491} 492 493Node* Diag(Graph* g, Node* in, DataType type) { 494 Node* ret; 495 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Diag") 496 .Input(in) 497 .Attr("T", type) 498 .Finalize(g, &ret)); 499 return ret; 500} 501 502Node* DiagPart(Graph* g, Node* in, DataType type) { 503 Node* ret; 504 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "DiagPart") 505 .Input(in) 506 .Attr("T", type) 507 .Finalize(g, &ret)); 508 return ret; 509} 510 511void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } 512 513} // end namespace graph 514} // end namespace test 515} // end namespace tensorflow 516