1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/control_flow_ops.h"
17
18#include "tensorflow/core/framework/op_kernel.h"
19#include "tensorflow/core/framework/register_types.h"
20#include "tensorflow/core/framework/tensor.h"
21#include "tensorflow/core/framework/types.h"
22#include "tensorflow/core/platform/macros.h"
23
24namespace tensorflow {
25
26void SwitchOp::Compute(OpKernelContext* context) {
27  const Tensor& outputPorts = context->input(1);
28  OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()),
29              errors::InvalidArgument("The second input must be a scalar, "
30                                      "but it has shape ",
31                                      outputPorts.shape().DebugString()));
32
33  bool pred = outputPorts.scalar<bool>()();
34  int port = (pred) ? 1 : 0;
35  if (context->input_is_ref(0)) {
36    context->forward_ref_input_to_ref_output(0, port);
37  } else {
38    context->set_output(port, context->input(0));
39  }
40}
41
42#define REGISTER_CPU_SWITCH(type)                         \
43  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
44                              .Device(DEVICE_CPU)         \
45                              .HostMemory("pred")         \
46                              .TypeConstraint<type>("T"), \
47                          SwitchOp)
48
49#define REGISTER_CPU_REF_SWITCH(type)                     \
50  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
51                              .Device(DEVICE_CPU)         \
52                              .HostMemory("pred")         \
53                              .TypeConstraint<type>("T"), \
54                          SwitchOp)
55
56#define REGISTER_GPU_SWITCH(type)                         \
57  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
58                              .Device(DEVICE_GPU)         \
59                              .HostMemory("pred")         \
60                              .TypeConstraint<type>("T"), \
61                          SwitchOp)
62
63#define REGISTER_GPU_REF_SWITCH(type)                     \
64  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
65                              .Device(DEVICE_GPU)         \
66                              .HostMemory("pred")         \
67                              .TypeConstraint<type>("T"), \
68                          SwitchOp)
69
70TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
71TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
72TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
73
74TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
75TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
76
77#undef REGISTER_CPU_SWITCH
78#undef REGISTER_CPU_REF_SWITCH
79#undef REGISTER_GPU_SWITCH
80#undef REGISTER_GPU_REF_SWITCH
81
82// Special GPU kernels for int32 and string.
83// TODO(b/25387198): Also enable int32 in device memory. This kernel
84// registration requires all int32 inputs and outputs to be in host memory.
85#define REGISTER_GPU_HOST_KERNEL(type)                    \
86  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
87                              .Device(DEVICE_GPU)         \
88                              .HostMemory("data")         \
89                              .HostMemory("pred")         \
90                              .HostMemory("output_false") \
91                              .HostMemory("output_true")  \
92                              .TypeConstraint<type>("T"), \
93                          SwitchOp)
94
95#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
96  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
97                              .Device(DEVICE_GPU)         \
98                              .HostMemory("data")         \
99                              .HostMemory("pred")         \
100                              .HostMemory("output_false") \
101                              .HostMemory("output_true")  \
102                              .TypeConstraint<type>("T"), \
103                          SwitchOp)
104
105REGISTER_GPU_HOST_KERNEL(int32);
106REGISTER_GPU_HOST_REF_KERNEL(int32);
107REGISTER_GPU_HOST_KERNEL(bool);
108REGISTER_GPU_HOST_REF_KERNEL(bool);
109REGISTER_GPU_HOST_KERNEL(string);
110REGISTER_GPU_HOST_REF_KERNEL(string);
111
112#undef REGISTER_GPU_HOST_KERNEL
113#undef REGISTER_GPU_HOST_REF_KERNEL
114
#ifdef TENSORFLOW_USE_SYCL
// SYCL Switch / RefSwitch kernels with only "pred" declared as host memory,
// registered for TF_CALL_REAL_NUMBER_TYPES_NO_INT32.
#define REGISTER_SYCL_SWITCH_OP(op_name, type)            \
  REGISTER_KERNEL_BUILDER(Name(op_name)                   \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)
#define REGISTER_SYCL_SWITCH(type) REGISTER_SYCL_SWITCH_OP("Switch", type)
#define REGISTER_SYCL_REF_SWITCH(type) \
  REGISTER_SYCL_SWITCH_OP("RefSwitch", type)

TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH);
TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);

#undef REGISTER_SYCL_REF_SWITCH
#undef REGISTER_SYCL_SWITCH
#undef REGISTER_SYCL_SWITCH_OP

// SYCL Switch / RefSwitch kernels that declare every argument ("data",
// "pred", "output_false", "output_true") as host memory.
#define REGISTER_SYCL_HOST_SWITCH_OP(op_name, type)       \
  REGISTER_KERNEL_BUILDER(Name(op_name)                   \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_SYCL_HOST_SWITCH_OP("Switch", bool);
REGISTER_SYCL_HOST_SWITCH_OP("Switch", string);
REGISTER_SYCL_HOST_SWITCH_OP("Switch", int32);

REGISTER_SYCL_HOST_SWITCH_OP("RefSwitch", int32);
REGISTER_SYCL_HOST_SWITCH_OP("RefSwitch", bool);
REGISTER_SYCL_HOST_SWITCH_OP("RefSwitch", string);

#undef REGISTER_SYCL_HOST_SWITCH_OP
#endif  // TENSORFLOW_USE_SYCL
166
167class RefSelectOp : public OpKernel {
168 public:
169  explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) {
170    OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_));
171  }
172
173  void Compute(OpKernelContext* context) override {
174    const Tensor& index_tensor = context->input(0);
175    OP_REQUIRES(context, TensorShapeUtils::IsScalar(index_tensor.shape()),
176                errors::InvalidArgument("Index must be a scalar, "
177                                        "but it has shape ",
178                                        index_tensor.shape().DebugString()));
179
180    int32 index = index_tensor.scalar<int32>()();
181
182    OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_,
183                errors::InvalidArgument("Index must be in the range [0, ",
184                                        num_ref_inputs_, ") but got ", index));
185    context->forward_ref_input_to_ref_output(index + 1, 0);
186  }
187
188  bool IsExpensive() override { return false; }
189
190  ~RefSelectOp() override {}
191
192  TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp);
193
194 private:
195  int num_ref_inputs_;
196};
197
198#define REGISTER_CPU_REF_SELECT(type)                     \
199  REGISTER_KERNEL_BUILDER(Name("RefSelect")               \
200                              .Device(DEVICE_CPU)         \
201                              .HostMemory("index")        \
202                              .TypeConstraint<type>("T"), \
203                          RefSelectOp)
204TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT);
205
206#undef REGISTER_CPU_REF_SWITCH
207
208MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) {
209  const DataType dt = context->input_type(0);
210  const int num_in = context->num_inputs();
211  OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt),
212                                                  {dt, DT_INT32}));
213}
214
215void MergeOp::Compute(OpKernelContext* context) {
216  bool input_seen = false;
217  for (int i = 0; i < context->num_inputs(); ++i) {
218    if (context->has_input(i)) {
219      if (input_seen) {
220        context->SetStatus(
221            errors::Internal("Merge can not have more than one valid input."));
222        return;
223      }
224      input_seen = true;
225
226      if (IsRefType(context->input_dtype(i))) {
227        context->forward_ref_input_to_ref_output(i, 0);
228      } else {
229        context->set_output(0, context->input(i));
230      }
231      Tensor* value_index = nullptr;
232      OP_REQUIRES_OK(
233          context, context->allocate_output(1, TensorShape({}), &value_index));
234      value_index->scalar<int32>()() = i;
235    }
236  }
237}
238
239REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
240REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);
241
242#define REGISTER_GPU_KERNEL(type)                         \
243  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
244                              .Device(DEVICE_GPU)         \
245                              .TypeConstraint<type>("T")  \
246                              .HostMemory("value_index"), \
247                          MergeOp);
248
249#define REGISTER_GPU_REF_KERNEL(type)                     \
250  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
251                              .Device(DEVICE_GPU)         \
252                              .TypeConstraint<type>("T")  \
253                              .HostMemory("value_index"), \
254                          MergeOp);
255
256TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
257TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
258REGISTER_GPU_KERNEL(bool);
259REGISTER_GPU_REF_KERNEL(bool);
260
261#undef REGISTER_GPU_KERNEL
262#undef REGISTER_GPU_REF_KERNEL
263
#ifdef TENSORFLOW_USE_SYCL
// SYCL Merge kernels: constrained on "T", with "value_index" declared as host
// memory. The macro body does not end in a semicolon, matching the other
// registration macros in this file.
#define REGISTER_SYCL_KERNEL(type)                        \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

// SYCL RefMerge kernels, same layout as above.
#define REGISTER_SYCL_REF_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp)
REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);

#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#endif  // TENSORFLOW_USE_SYCL
286
287// Special GPU kernels for int32 and string.
288// TODO(b/25387198): Also enable int32 in device memory. This kernel
289// registration requires all int32 inputs and outputs to be in host memory.
290#define REGISTER_GPU_HOST_KERNEL(type)                    \
291  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
292                              .Device(DEVICE_GPU)         \
293                              .HostMemory("inputs")       \
294                              .HostMemory("output")       \
295                              .HostMemory("value_index")  \
296                              .TypeConstraint<type>("T"), \
297                          MergeOp);                       \
298  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
299                              .Device(DEVICE_GPU)         \
300                              .HostMemory("inputs")       \
301                              .HostMemory("output")       \
302                              .HostMemory("value_index")  \
303                              .TypeConstraint<type>("T"), \
304                          MergeOp)
305
306REGISTER_GPU_HOST_KERNEL(int32);
307REGISTER_GPU_HOST_KERNEL(string);
308REGISTER_GPU_HOST_KERNEL(ResourceHandle);
309
310#undef REGISTER_GPU_HOST_KERNEL
311
#ifdef TENSORFLOW_USE_SYCL
// SYCL Merge / RefMerge kernels that declare "inputs", "output" and
// "value_index" as host memory; registered for int32, string and
// ResourceHandle.
#define REGISTER_SYCL_HOST_MERGE(op_name, type)           \
  REGISTER_KERNEL_BUILDER(Name(op_name)                   \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_SYCL_HOST_MERGE("Merge", int32);
REGISTER_SYCL_HOST_MERGE("RefMerge", int32);
REGISTER_SYCL_HOST_MERGE("Merge", string);
REGISTER_SYCL_HOST_MERGE("RefMerge", string);
REGISTER_SYCL_HOST_MERGE("Merge", ResourceHandle);
REGISTER_SYCL_HOST_MERGE("RefMerge", ResourceHandle);

#undef REGISTER_SYCL_HOST_MERGE
#endif  // TENSORFLOW_USE_SYCL
335
336void EnterOp::Compute(OpKernelContext* context) {
337  if (IsRefType(context->input_dtype(0))) {
338    context->forward_ref_input_to_ref_output(0, 0);
339  } else {
340    context->set_output(0, context->input(0));
341  }
342}
343
344REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp);
345REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);
346
347#define REGISTER_GPU_KERNEL(type) \
348  REGISTER_KERNEL_BUILDER(        \
349      Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
350#define REGISTER_GPU_REF_KERNEL(type) \
351  REGISTER_KERNEL_BUILDER(            \
352      Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
353
354TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
355TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
356REGISTER_GPU_KERNEL(bool);
357REGISTER_GPU_REF_KERNEL(bool);
358
359#undef REGISTER_GPU_KERNEL
360#undef REGISTER_GPU_REF_KERNEL
361
#ifdef TENSORFLOW_USE_SYCL
// SYCL Enter / RefEnter kernels constrained on "T"; registered for bool and
// TF_CALL_NUMBER_TYPES_NO_INT32.
#define REGISTER_SYCL_ENTER_OP(op_name, type) \
  REGISTER_KERNEL_BUILDER(                    \
      Name(op_name).Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
#define REGISTER_SYCL_KERNEL(type) REGISTER_SYCL_ENTER_OP("Enter", type)
#define REGISTER_SYCL_REF_KERNEL(type) \
  REGISTER_SYCL_ENTER_OP("RefEnter", type)

REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);

#undef REGISTER_SYCL_REF_KERNEL
#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_ENTER_OP

// SYCL Enter / RefEnter kernels that declare "data" and "output" as host
// memory. ResourceHandle is registered for Enter only.
#define REGISTER_SYCL_HOST_ENTER_OP(op_name, type)        \
  REGISTER_KERNEL_BUILDER(Name(op_name)                   \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

REGISTER_SYCL_HOST_ENTER_OP("Enter", int32);
REGISTER_SYCL_HOST_ENTER_OP("RefEnter", int32);
REGISTER_SYCL_HOST_ENTER_OP("Enter", string);
REGISTER_SYCL_HOST_ENTER_OP("RefEnter", string);
REGISTER_SYCL_HOST_ENTER_OP("Enter", ResourceHandle);

#undef REGISTER_SYCL_HOST_ENTER_OP
#endif  // TENSORFLOW_USE_SYCL
402
403// Special GPU kernels for int32 and string.
404// TODO(b/25387198): Also enable int32 in device memory. This kernel
405// registration requires all int32 inputs and outputs to be in host memory.
406#define REGISTER_GPU_HOST_KERNEL(type)                    \
407  REGISTER_KERNEL_BUILDER(Name("Enter")                   \
408                              .Device(DEVICE_GPU)         \
409                              .HostMemory("data")         \
410                              .HostMemory("output")       \
411                              .TypeConstraint<type>("T"), \
412                          EnterOp)
413
414#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
415  REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
416                              .Device(DEVICE_GPU)         \
417                              .HostMemory("data")         \
418                              .HostMemory("output")       \
419                              .TypeConstraint<type>("T"), \
420                          EnterOp)
421
422REGISTER_GPU_HOST_KERNEL(int32);
423REGISTER_GPU_HOST_REF_KERNEL(int32);
424REGISTER_GPU_HOST_KERNEL(string);
425REGISTER_GPU_HOST_REF_KERNEL(string);
426REGISTER_GPU_HOST_KERNEL(ResourceHandle);
427
428#undef REGISTER_GPU_HOST_KERNEL
429#undef REGISTER_GPU_HOST_REF_KERNEL
430
431void ExitOp::Compute(OpKernelContext* context) {
432  if (IsRefType(context->input_dtype(0))) {
433    context->forward_ref_input_to_ref_output(0, 0);
434  } else {
435    context->set_output(0, context->input(0));
436  }
437}
438
439REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp);
440REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);
441
442#define REGISTER_GPU_KERNEL(type) \
443  REGISTER_KERNEL_BUILDER(        \
444      Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
445#define REGISTER_GPU_REF_KERNEL(type) \
446  REGISTER_KERNEL_BUILDER(            \
447      Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
448
449TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
450TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
451REGISTER_GPU_KERNEL(bool);
452REGISTER_GPU_REF_KERNEL(bool);
453
454#undef REGISTER_GPU_KERNEL
455#undef REGISTER_GPU_REF_KERNEL
456
#ifdef TENSORFLOW_USE_SYCL
// SYCL kernels: registers both Exit and RefExit, constrained on "T", for one
// type. The macro body does not end in a semicolon, matching the other
// registration macros in this file.
#define REGISTER_SYCL_KERNEL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

// Only REGISTER_SYCL_KERNEL was defined in this section.
#undef REGISTER_SYCL_KERNEL

// SYCL Exit / RefExit kernels that declare "data" and "output" as host
// memory; registered for int32 and string.
#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp);                        \
  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(string);
#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL
487
488// Special GPU kernels for int32 and string.
489// TODO(b/25387198): Also enable int32 in device memory. This kernel
490// registration requires all int32 inputs and outputs to be in host memory.
491#define REGISTER_GPU_HOST_KERNEL(type)                    \
492  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
493                              .Device(DEVICE_GPU)         \
494                              .HostMemory("data")         \
495                              .HostMemory("output")       \
496                              .TypeConstraint<type>("T"), \
497                          ExitOp);                        \
498  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
499                              .Device(DEVICE_GPU)         \
500                              .HostMemory("data")         \
501                              .HostMemory("output")       \
502                              .TypeConstraint<type>("T"), \
503                          ExitOp)
504
505REGISTER_GPU_HOST_KERNEL(int32);
506REGISTER_GPU_HOST_KERNEL(string);
507
508#undef REGISTER_GPU_HOST_KERNEL
509
510void NextIterationOp::Compute(OpKernelContext* context) {
511  if (IsRefType(context->input_dtype(0))) {
512    context->forward_ref_input_to_ref_output(0, 0);
513  } else {
514    context->set_output(0, context->input(0));
515  }
516}
517
518REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
519                        NextIterationOp);
520REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
521                        NextIterationOp);
522
523#define REGISTER_GPU_KERNEL(type)                                            \
524  REGISTER_KERNEL_BUILDER(                                                   \
525      Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
526      NextIterationOp);                                                      \
527  REGISTER_KERNEL_BUILDER(                                                   \
528      Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
529      NextIterationOp)
530
531TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
532REGISTER_GPU_KERNEL(bool);
533
534#undef REGISTER_GPU_KERNEL
535
536// Special GPU kernels for int32 and string.
537// TODO(b/25387198): Also enable int32 in device memory. This kernel
538// registration requires all int32 inputs and outputs to be in host memory.
539#define REGISTER_GPU_HOST_KERNEL(type)                    \
540  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
541                              .Device(DEVICE_GPU)         \
542                              .HostMemory("data")         \
543                              .HostMemory("output")       \
544                              .TypeConstraint<type>("T"), \
545                          NextIterationOp);               \
546  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
547                              .Device(DEVICE_GPU)         \
548                              .HostMemory("data")         \
549                              .HostMemory("output")       \
550                              .TypeConstraint<type>("T"), \
551                          NextIterationOp)
552
553REGISTER_GPU_HOST_KERNEL(int32);
554REGISTER_GPU_HOST_KERNEL(string);
555
556#undef REGISTER_GPU_HOST_KERNEL
557
#ifdef TENSORFLOW_USE_SYCL
// SYCL NextIteration / RefNextIteration kernels constrained on "T";
// registered for bool and TF_CALL_NUMBER_TYPES_NO_INT32.
#define REGISTER_SYCL_NEXT_ITERATION(op_name, type)                \
  REGISTER_KERNEL_BUILDER(                                         \
      Name(op_name).Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      NextIterationOp)

// Registers both the value and the reference flavour for one type.
#define REGISTER_SYCL_KERNEL(type)                     \
  REGISTER_SYCL_NEXT_ITERATION("NextIteration", type); \
  REGISTER_SYCL_NEXT_ITERATION("RefNextIteration", type)

REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_NEXT_ITERATION

// SYCL NextIteration / RefNextIteration kernels that declare "data" and
// "output" as host memory; registered for int32 and string.
#define REGISTER_SYCL_HOST_NEXT_ITERATION(op_name, type)  \
  REGISTER_KERNEL_BUILDER(Name(op_name)                   \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_SYCL_HOST_NEXT_ITERATION("NextIteration", int32);
REGISTER_SYCL_HOST_NEXT_ITERATION("RefNextIteration", int32);
REGISTER_SYCL_HOST_NEXT_ITERATION("NextIteration", string);
REGISTER_SYCL_HOST_NEXT_ITERATION("RefNextIteration", string);

#undef REGISTER_SYCL_HOST_NEXT_ITERATION
#endif  // TENSORFLOW_USE_SYCL
589
590// A LoopCond op has one input and one output. The input is a boolean
591// scalar representing the taken branches of the "pivot" Switch that
592// determines loop termination. As a contract, any high-level front-end
593// should always use port '0' of the "pivot" switches for loop exit.
594class LoopCondOp : public OpKernel {
595 public:
596  explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
597
598  void Compute(OpKernelContext* context) override {
599    context->set_output(0, context->input(0));
600  }
601
602  bool IsExpensive() override { return false; }
603
604  ~LoopCondOp() override {}
605
606  TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp);
607};
608
609REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
610REGISTER_KERNEL_BUILDER(Name("LoopCond")
611                            .Device(DEVICE_GPU)
612                            .HostMemory("input")
613                            .HostMemory("output"),
614                        LoopCondOp);
615
616#ifdef TENSORFLOW_USE_SYCL
617REGISTER_KERNEL_BUILDER(Name("LoopCond")
618                            .Device(DEVICE_SYCL)
619                            .HostMemory("input")
620                            .HostMemory("output"),
621                        LoopCondOp);
622#endif  // TENSORFLOW_USE_SYCL
623
624// ControlTrigger kernels
625REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
626                        ControlTriggerOp);
627
628REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU),
629                        ControlTriggerOp);
630
631#ifdef TENSORFLOW_USE_SYCL
632REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL),
633                        ControlTriggerOp);
634#endif  // TENSORFLOW_USE_SYCL
635
636// When called, abort op will abort the current process. This can be used to
637// abort remote PSs when needed.
638class AbortOp : public OpKernel {
639 public:
640  explicit AbortOp(OpKernelConstruction* context) : OpKernel(context) {
641    OP_REQUIRES_OK(context, context->GetAttr("error_msg", &error_msg_));
642    OP_REQUIRES_OK(
643        context, context->GetAttr("exit_without_error", &exit_without_error_));
644  }
645
646  void Compute(OpKernelContext* context) override {
647    if (!exit_without_error_) {
648      LOG(FATAL) << "Abort_op intentional failure; " << error_msg_;
649    } else {
650      LOG(WARNING) << "Exiting the process: " << error_msg_;
651      exit(0);
652    }
653  }
654
655 private:
656  string error_msg_;
657  bool exit_without_error_;
658};
659
660REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp);
661
662}  // namespace tensorflow
663