1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/kernels/control_flow_ops.h" 17 18#include "tensorflow/core/framework/op_kernel.h" 19#include "tensorflow/core/framework/register_types.h" 20#include "tensorflow/core/framework/tensor.h" 21#include "tensorflow/core/framework/types.h" 22#include "tensorflow/core/platform/macros.h" 23 24namespace tensorflow { 25 26void SwitchOp::Compute(OpKernelContext* context) { 27 const Tensor& outputPorts = context->input(1); 28 OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()), 29 errors::InvalidArgument("The second input must be a scalar, " 30 "but it has shape ", 31 outputPorts.shape().DebugString())); 32 33 bool pred = outputPorts.scalar<bool>()(); 34 int port = (pred) ? 
1 : 0; 35 if (context->input_is_ref(0)) { 36 context->forward_ref_input_to_ref_output(0, port); 37 } else { 38 context->set_output(port, context->input(0)); 39 } 40} 41 42#define REGISTER_CPU_SWITCH(type) \ 43 REGISTER_KERNEL_BUILDER(Name("Switch") \ 44 .Device(DEVICE_CPU) \ 45 .HostMemory("pred") \ 46 .TypeConstraint<type>("T"), \ 47 SwitchOp) 48 49#define REGISTER_CPU_REF_SWITCH(type) \ 50 REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 51 .Device(DEVICE_CPU) \ 52 .HostMemory("pred") \ 53 .TypeConstraint<type>("T"), \ 54 SwitchOp) 55 56#define REGISTER_GPU_SWITCH(type) \ 57 REGISTER_KERNEL_BUILDER(Name("Switch") \ 58 .Device(DEVICE_GPU) \ 59 .HostMemory("pred") \ 60 .TypeConstraint<type>("T"), \ 61 SwitchOp) 62 63#define REGISTER_GPU_REF_SWITCH(type) \ 64 REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 65 .Device(DEVICE_GPU) \ 66 .HostMemory("pred") \ 67 .TypeConstraint<type>("T"), \ 68 SwitchOp) 69 70TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH); 71TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH); 72TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH); 73 74TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH); 75TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH); 76 77#undef REGISTER_CPU_SWITCH 78#undef REGISTER_CPU_REF_SWITCH 79#undef REGISTER_GPU_SWITCH 80#undef REGISTER_GPU_REF_SWITCH 81 82// Special GPU kernels for int32 and string. 83// TODO(b/25387198): Also enable int32 in device memory. This kernel 84// registration requires all int32 inputs and outputs to be in host memory. 
85#define REGISTER_GPU_HOST_KERNEL(type) \ 86 REGISTER_KERNEL_BUILDER(Name("Switch") \ 87 .Device(DEVICE_GPU) \ 88 .HostMemory("data") \ 89 .HostMemory("pred") \ 90 .HostMemory("output_false") \ 91 .HostMemory("output_true") \ 92 .TypeConstraint<type>("T"), \ 93 SwitchOp) 94 95#define REGISTER_GPU_HOST_REF_KERNEL(type) \ 96 REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 97 .Device(DEVICE_GPU) \ 98 .HostMemory("data") \ 99 .HostMemory("pred") \ 100 .HostMemory("output_false") \ 101 .HostMemory("output_true") \ 102 .TypeConstraint<type>("T"), \ 103 SwitchOp) 104 105REGISTER_GPU_HOST_KERNEL(int32); 106REGISTER_GPU_HOST_REF_KERNEL(int32); 107REGISTER_GPU_HOST_KERNEL(bool); 108REGISTER_GPU_HOST_REF_KERNEL(bool); 109REGISTER_GPU_HOST_KERNEL(string); 110REGISTER_GPU_HOST_REF_KERNEL(string); 111 112#undef REGISTER_GPU_HOST_KERNEL 113#undef REGISTER_GPU_HOST_REF_KERNEL 114 115#ifdef TENSORFLOW_USE_SYCL 116#define REGISTER_SYCL_SWITCH(type) \ 117 REGISTER_KERNEL_BUILDER(Name("Switch") \ 118 .Device(DEVICE_SYCL) \ 119 .HostMemory("pred") \ 120 .TypeConstraint<type>("T"), \ 121 SwitchOp) 122TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH); 123 124#define REGISTER_SYCL_REF_SWITCH(type) \ 125 REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 126 .Device(DEVICE_SYCL) \ 127 .HostMemory("pred") \ 128 .TypeConstraint<type>("T"), \ 129 SwitchOp) 130TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH); 131 132#undef REGISTER_SYCL_SWITCH 133#undef REGISTER_SYCL_REF_SWITCH 134 135#define REGISTER_SYCL_HOST_KERNEL(type) \ 136 REGISTER_KERNEL_BUILDER(Name("Switch") \ 137 .Device(DEVICE_SYCL) \ 138 .HostMemory("data") \ 139 .HostMemory("pred") \ 140 .HostMemory("output_false") \ 141 .HostMemory("output_true") \ 142 .TypeConstraint<type>("T"), \ 143 SwitchOp) 144 145REGISTER_SYCL_HOST_KERNEL(bool); 146REGISTER_SYCL_HOST_KERNEL(string); 147REGISTER_SYCL_HOST_KERNEL(int32); 148 149#define REGISTER_SYCL_HOST_REF_KERNEL(type) \ 150 REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 151 
.Device(DEVICE_SYCL) \ 152 .HostMemory("data") \ 153 .HostMemory("pred") \ 154 .HostMemory("output_false") \ 155 .HostMemory("output_true") \ 156 .TypeConstraint<type>("T"), \ 157 SwitchOp) 158 159REGISTER_SYCL_HOST_REF_KERNEL(int32); 160REGISTER_SYCL_HOST_REF_KERNEL(bool); 161REGISTER_SYCL_HOST_REF_KERNEL(string); 162 163#undef REGISTER_SYCL_HOST_KERNEL 164#undef REGISTER_SYCL_HOST_REF_KERNEL 165#endif // TENSORFLOW_USE_SYCL 166 167class RefSelectOp : public OpKernel { 168 public: 169 explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) { 170 OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_)); 171 } 172 173 void Compute(OpKernelContext* context) override { 174 const Tensor& index_tensor = context->input(0); 175 OP_REQUIRES(context, TensorShapeUtils::IsScalar(index_tensor.shape()), 176 errors::InvalidArgument("Index must be a scalar, " 177 "but it has shape ", 178 index_tensor.shape().DebugString())); 179 180 int32 index = index_tensor.scalar<int32>()(); 181 182 OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_, 183 errors::InvalidArgument("Index must be in the range [0, ", 184 num_ref_inputs_, ") but got ", index)); 185 context->forward_ref_input_to_ref_output(index + 1, 0); 186 } 187 188 bool IsExpensive() override { return false; } 189 190 ~RefSelectOp() override {} 191 192 TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp); 193 194 private: 195 int num_ref_inputs_; 196}; 197 198#define REGISTER_CPU_REF_SELECT(type) \ 199 REGISTER_KERNEL_BUILDER(Name("RefSelect") \ 200 .Device(DEVICE_CPU) \ 201 .HostMemory("index") \ 202 .TypeConstraint<type>("T"), \ 203 RefSelectOp) 204TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT); 205 206#undef REGISTER_CPU_REF_SWITCH 207 208MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) { 209 const DataType dt = context->input_type(0); 210 const int num_in = context->num_inputs(); 211 OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt), 212 {dt, DT_INT32})); 213} 214 
// Forwards the single available input to output 0 and reports which input
// slot was taken as a scalar int32 in output 1 ("value_index"). Fails the
// step if more than one input is present.
void MergeOp::Compute(OpKernelContext* context) {
  bool input_seen = false;
  for (int i = 0; i < context->num_inputs(); ++i) {
    if (context->has_input(i)) {
      if (input_seen) {
        // A second live input violates the Merge contract.
        context->SetStatus(
            errors::Internal("Merge can not have more than one valid input."));
        return;
      }
      input_seen = true;

      if (IsRefType(context->input_dtype(i))) {
        // Preserve ref-ness: forward the ref instead of copying the value.
        context->forward_ref_input_to_ref_output(i, 0);
      } else {
        context->set_output(0, context->input(i));
      }
      // Record the index of the input that was forwarded.
      Tensor* value_index = nullptr;
      OP_REQUIRES_OK(
          context, context->allocate_output(1, TensorShape({}), &value_index));
      value_index->scalar<int32>()() = i;
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);

// "value_index" is a scalar int32, so it stays in host memory on GPU/SYCL.
#define REGISTER_GPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

#define REGISTER_GPU_REF_KERNEL(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type)                        \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

#define REGISTER_SYCL_REF_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);
REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);

#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#endif  // TENSORFLOW_USE_SYCL

// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp);                       \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp);                       \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(string);
REGISTER_SYCL_HOST_KERNEL(ResourceHandle);

#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL

// Enter is a pass-through: input 0 is forwarded unchanged to output 0, with
// ref inputs forwarded as refs.
void EnterOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}
344REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp); 345REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp); 346 347#define REGISTER_GPU_KERNEL(type) \ 348 REGISTER_KERNEL_BUILDER( \ 349 Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp) 350#define REGISTER_GPU_REF_KERNEL(type) \ 351 REGISTER_KERNEL_BUILDER( \ 352 Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp) 353 354TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); 355TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); 356REGISTER_GPU_KERNEL(bool); 357REGISTER_GPU_REF_KERNEL(bool); 358 359#undef REGISTER_GPU_KERNEL 360#undef REGISTER_GPU_REF_KERNEL 361 362#ifdef TENSORFLOW_USE_SYCL 363#define REGISTER_SYCL_KERNEL(type) \ 364 REGISTER_KERNEL_BUILDER( \ 365 Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp) 366REGISTER_SYCL_KERNEL(bool); 367TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); 368 369#define REGISTER_SYCL_REF_KERNEL(type) \ 370 REGISTER_KERNEL_BUILDER( \ 371 Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp) 372REGISTER_SYCL_REF_KERNEL(bool); 373TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); 374 375#undef REGISTER_SYCL_KERNEL 376#undef REGISTER_SYCL_REF_KERNEL 377#define REGISTER_SYCL_HOST_KERNEL(type) \ 378 REGISTER_KERNEL_BUILDER(Name("Enter") \ 379 .Device(DEVICE_SYCL) \ 380 .HostMemory("data") \ 381 .HostMemory("output") \ 382 .TypeConstraint<type>("T"), \ 383 EnterOp) 384 385#define REGISTER_SYCL_HOST_REF_KERNEL(type) \ 386 REGISTER_KERNEL_BUILDER(Name("RefEnter") \ 387 .Device(DEVICE_SYCL) \ 388 .HostMemory("data") \ 389 .HostMemory("output") \ 390 .TypeConstraint<type>("T"), \ 391 EnterOp) 392 393REGISTER_SYCL_HOST_KERNEL(int32); 394REGISTER_SYCL_HOST_REF_KERNEL(int32); 395REGISTER_SYCL_HOST_KERNEL(string); 396REGISTER_SYCL_HOST_REF_KERNEL(string); 397REGISTER_SYCL_HOST_KERNEL(ResourceHandle); 398 399#undef REGISTER_SYCL_HOST_KERNEL 400#undef 
REGISTER_SYCL_HOST_REF_KERNEL 401#endif // TENSORFLOW_USE_SYCL 402 403// Special GPU kernels for int32 and string. 404// TODO(b/25387198): Also enable int32 in device memory. This kernel 405// registration requires all int32 inputs and outputs to be in host memory. 406#define REGISTER_GPU_HOST_KERNEL(type) \ 407 REGISTER_KERNEL_BUILDER(Name("Enter") \ 408 .Device(DEVICE_GPU) \ 409 .HostMemory("data") \ 410 .HostMemory("output") \ 411 .TypeConstraint<type>("T"), \ 412 EnterOp) 413 414#define REGISTER_GPU_HOST_REF_KERNEL(type) \ 415 REGISTER_KERNEL_BUILDER(Name("RefEnter") \ 416 .Device(DEVICE_GPU) \ 417 .HostMemory("data") \ 418 .HostMemory("output") \ 419 .TypeConstraint<type>("T"), \ 420 EnterOp) 421 422REGISTER_GPU_HOST_KERNEL(int32); 423REGISTER_GPU_HOST_REF_KERNEL(int32); 424REGISTER_GPU_HOST_KERNEL(string); 425REGISTER_GPU_HOST_REF_KERNEL(string); 426REGISTER_GPU_HOST_KERNEL(ResourceHandle); 427 428#undef REGISTER_GPU_HOST_KERNEL 429#undef REGISTER_GPU_HOST_REF_KERNEL 430 431void ExitOp::Compute(OpKernelContext* context) { 432 if (IsRefType(context->input_dtype(0))) { 433 context->forward_ref_input_to_ref_output(0, 0); 434 } else { 435 context->set_output(0, context->input(0)); 436 } 437} 438 439REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp); 440REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp); 441 442#define REGISTER_GPU_KERNEL(type) \ 443 REGISTER_KERNEL_BUILDER( \ 444 Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp); 445#define REGISTER_GPU_REF_KERNEL(type) \ 446 REGISTER_KERNEL_BUILDER( \ 447 Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp); 448 449TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); 450TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); 451REGISTER_GPU_KERNEL(bool); 452REGISTER_GPU_REF_KERNEL(bool); 453 454#undef REGISTER_GPU_KERNEL 455#undef REGISTER_GPU_REF_KERNEL 456 457#ifdef TENSORFLOW_USE_SYCL 458#define REGISTER_SYCL_KERNEL(type) \ 459 
REGISTER_KERNEL_BUILDER( \ 460 Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \ 461 REGISTER_KERNEL_BUILDER( \ 462 Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); 463REGISTER_SYCL_KERNEL(bool); 464TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); 465 466#undef REGISTER_SYCL_KERNEL 467#undef REGISTER_SYCL_REF_KERNEL 468 469#define REGISTER_SYCL_HOST_KERNEL(type) \ 470 REGISTER_KERNEL_BUILDER(Name("Exit") \ 471 .Device(DEVICE_SYCL) \ 472 .HostMemory("data") \ 473 .HostMemory("output") \ 474 .TypeConstraint<type>("T"), \ 475 ExitOp); \ 476 REGISTER_KERNEL_BUILDER(Name("RefExit") \ 477 .Device(DEVICE_SYCL) \ 478 .HostMemory("data") \ 479 .HostMemory("output") \ 480 .TypeConstraint<type>("T"), \ 481 ExitOp) 482 483REGISTER_SYCL_HOST_KERNEL(int32); 484REGISTER_SYCL_HOST_KERNEL(string); 485#undef REGISTER_SYCL_HOST_KERNEL 486#endif // TENSORFLOW_USE_SYCL 487 488// Special GPU kernels for int32 and string. 489// TODO(b/25387198): Also enable int32 in device memory. This kernel 490// registration requires all int32 inputs and outputs to be in host memory. 
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp);                        \
  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);

#undef REGISTER_GPU_HOST_KERNEL

// NextIteration is a pass-through: input 0 is forwarded unchanged to
// output 0, with ref inputs forwarded as refs.
void NextIterationOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
                        NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
                        NextIterationOp);

// Registers both NextIteration and RefNextIteration for the given type.
#define REGISTER_GPU_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
      NextIterationOp);                                                      \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      NextIterationOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);

#undef REGISTER_GPU_KERNEL

// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp);               \
  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);

#undef REGISTER_GPU_HOST_KERNEL

#ifdef TENSORFLOW_USE_SYCL
// Registers both NextIteration and RefNextIteration for the given type.
#define REGISTER_SYCL_KERNEL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      NextIterationOp);                                                    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("RefNextIteration")                                             \
          .Device(DEVICE_SYCL)                                             \
          .TypeConstraint<type>("T"),                                      \
      NextIterationOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

#undef REGISTER_SYCL_KERNEL

// int32 and string inputs/outputs are kept in host memory on SYCL.
#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp);               \
  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(string);
#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL

// A LoopCond op has one input and one output. The input is a boolean
// scalar representing the taken branches of the "pivot" Switch that
// determines loop termination. As a contract, any high-level front-end
// should always use port '0' of the "pivot" switches for loop exit.
// Forwards its boolean scalar input to its output unchanged; marked cheap
// (IsExpensive() == false) so the executor can run it inline.
class LoopCondOp : public OpKernel {
 public:
  explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    context->set_output(0, context->input(0));
  }

  bool IsExpensive() override { return false; }

  ~LoopCondOp() override {}

  TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp);
};

REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
// The boolean scalar is kept in host memory on GPU.
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);
#endif  // TENSORFLOW_USE_SYCL

// ControlTrigger kernels
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
                        ControlTriggerOp);

REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU),
                        ControlTriggerOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL),
                        ControlTriggerOp);
#endif  // TENSORFLOW_USE_SYCL

// When called, abort op will abort the current process. This can be used to
// abort remote PSs when needed.
// Terminates the process when executed. With exit_without_error == false it
// aborts via LOG(FATAL), appending error_msg_; otherwise it logs a warning
// and exits with status 0.
class AbortOp : public OpKernel {
 public:
  explicit AbortOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("error_msg", &error_msg_));
    OP_REQUIRES_OK(
        context, context->GetAttr("exit_without_error", &exit_without_error_));
  }

  void Compute(OpKernelContext* context) override {
    if (!exit_without_error_) {
      // LOG(FATAL) aborts the process (non-zero exit) after logging.
      LOG(FATAL) << "Abort_op intentional failure; " << error_msg_;
    } else {
      LOG(WARNING) << "Exiting the process: " << error_msg_;
      exit(0);
    }
  }

 private:
  string error_msg_;         // "error_msg" attr: text appended to the log.
  bool exit_without_error_;  // "exit_without_error" attr: exit(0) vs abort.
};

REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp);

}  // namespace tensorflow