1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// Suite of datatypes to represent data-parallel kernel objects (code entities).
// KernelBase is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
20//
21// Users typically don't see KernelBase, they see typed kernels, analogous to a
22// typed function pointer. TypedKernels express their argument types via
23// template parameters like so:
24//
25//  TypedKernel<DeviceMemory<int>*, int>
26//
27// Which expresses a data parallel kernel signature for:
28//
29//  void(int*, int);
30//
31// And for a const memory region:
32//
33//  TypedKernel<const DeviceMemory<int>&, int>
34//
35// Corresponds to a data parallel kernel signature for:
36//
37//  void(const int*, int)
38//
// Note that kernels always have a void return type, so results typically must
// be memcpy'd from device memory to the host.
41//
42// Also note that a scalar integer residing in device memory and an array of
43// integers residing in device memory have the same signature: DeviceMemory<T>.
44// However, in the future, checks may be added for additional safety that arrays
45// of minimum sizes are passed when those minimum sizes are contractually
46// expected by the kernel.
47//
48// For user-defined types whose definitions are appropriately shared between the
49// host code doing the launching and the kernel code being launched, the user
50// defined types are similarly permitted to be expressed as residing in device
51// memory:
52//
53//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
54//
55// And, when the alignment and padding are agreed upon, POD types will also be
56// able to be passed by value; for example, it is a common idiom to specify a
57// bunch of options simultaneously with a structure:
58//
59//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
60//
61// Which corresponds to a data parallel kernel signature like:
62//
63//  void(MyOptionsStructurePassedByValue value, float *result);
64//
65// Users typically won't need to type out the TypedKernel signature in full, it
66// will be typedef'd by automatically generated code; for example, see
67// perftools::gputools::executor_sample::VecReduceAddKernel.
68
69#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
70#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
71
72#include <array>
73#include <memory>
74#include <tuple>
75#include <type_traits>
76#include <vector>
77
78#include "tensorflow/stream_executor/device_memory.h"
79#include "tensorflow/stream_executor/kernel_cache_config.h"
80#include "tensorflow/stream_executor/lib/array_slice.h"
81#include "tensorflow/stream_executor/lib/inlined_vector.h"
82#include "tensorflow/stream_executor/lib/stringpiece.h"
83#include "tensorflow/stream_executor/platform/port.h"
84
85namespace perftools {
86namespace gputools {
87
// Forward declarations of types referenced by the declarations in this header.
class DeviceMemoryBase;
template <typename ElemT>
class DeviceMemory;
class StreamExecutor;

namespace internal {
// Platform-specific kernel implementation that KernelBase holds and exposes
// through KernelBase::implementation().
class KernelInterface;
}  // namespace internal
96
97// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
98// registers allocated, shared memory used, etc.
99// Not all platforms support reporting of all information, so each accessor
100// returns false if the associated field is not populated in the underlying
101// platform.
102class KernelMetadata {
103 public:
104  KernelMetadata()
105      : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}
106
107  // Returns the number of registers used per thread executing this kernel.
108  bool registers_per_thread(int *registers_per_thread) const;
109
110  // Sets the number of registers used per thread executing this kernel.
111  void set_registers_per_thread(int registers_per_thread);
112
113  // Returns the amount of [static] shared memory used per block executing this
114  // kernel. Note that dynamic shared memory allocations are not (and can not)
115  // be reported here (since they're not specified until kernel launch time).
116  bool shared_memory_bytes(int *shared_memory_bytes) const;
117
118  // Sets the amount of [static] shared memory used per block executing this
119  // kernel.
120  void set_shared_memory_bytes(int shared_memory_bytes);
121
122 private:
123  // Holds the value returned by registers_per_thread above.
124  bool has_registers_per_thread_;
125  int registers_per_thread_;
126
127  // Holds the value returned by shared_memory_bytes above.
128  bool has_shared_memory_bytes_;
129  int64 shared_memory_bytes_;
130};
131
132// A data-parallel kernel (code entity) for launching via the StreamExecutor,
133// analogous to a void* device function pointer. See TypedKernel for the typed
134// variant.
135//
136// Thread-compatible.
class KernelBase {
 public:
  // Move constructor; defined out of line.
  KernelBase(KernelBase &&from);

  // Constructs an "empty" (not-yet-loaded) kernel instance.
  //
  // parent is the StreamExecutor that will be responsible for loading the
  // implementation of this kernel. It must not be null.
  explicit KernelBase(StreamExecutor *parent);

  // Test-only constructor that can take a mock KernelInterface implementation.
  // NOTE(review): implementation_ is a std::unique_ptr, so this presumably
  // takes ownership of implementation -- confirm against the definition.
  KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);

  // Releases resources associated with the kernel instance (i.e.
  // platform-specific implementation).
  ~KernelBase();

  // Returns the number of parameters that this kernel accepts. (Arity refers to
  // nullary, unary, ...).
  unsigned Arity() const;

  // Returns the StreamExecutor that represents the platform this kernel
  // executes upon.
  StreamExecutor *parent() const { return parent_; }

  // Returns a const pointer to the (opaque) platform-dependent implementation.
  // The pointer is borrowed: this object keeps ownership.
  const internal::KernelInterface *implementation() const {
    return implementation_.get();
  }

  // Returns a non-const pointer to the (opaque) platform-dependent
  // implementation. The pointer is borrowed: this object keeps ownership.
  internal::KernelInterface *implementation() { return implementation_.get(); }

  // Replaces the stored metadata with a copy of metadata.
  void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }

  // Returns the stored kernel metadata.
  const KernelMetadata &metadata() const { return metadata_; }

  // Sets the preferred cache configuration for a kernel. This is just a
  // suggestion to the runtime, and may not be honored during execution.
  void SetPreferredCacheConfig(KernelCacheConfig config);

  // Gets the preferred cache configuration for a kernel.
  KernelCacheConfig GetPreferredCacheConfig() const;

  // Sets the kernel's name; defined out of line.
  // NOTE(review): presumably also refreshes demangled_name_ -- confirm against
  // the definition.
  void set_name(port::StringPiece name);

  // Returns the stored kernel name.
  const string &name() const { return name_; }

  // Returns the stored demangled form of the kernel name.
  const string &demangled_name() const { return demangled_name_; }

 private:
  // The StreamExecutor that loads this kernel object.
  StreamExecutor *parent_;

  // Implementation delegated to for platform-specific functionality. Owned by
  // this object.
  std::unique_ptr<internal::KernelInterface> implementation_;

  // Values returned by name() and demangled_name() respectively.
  string name_;
  string demangled_name_;

  // Value returned by metadata().
  KernelMetadata metadata_;

  SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
};
200
201// Whether T is a DeviceMemory-family pointer.
202template <typename T>
203struct IsDeviceMemoryPointer {
204  static constexpr bool value = false;
205};
206
207template <typename U>
208struct IsDeviceMemoryPointer<DeviceMemory<U> *> {
209  static constexpr bool value = true;
210};
211
212template <>
213struct IsDeviceMemoryPointer<DeviceMemoryBase *> {
214  static constexpr bool value = true;
215};
216
217// Whether T is a DeviceMemory-family value-like thing (which includes a
218// reference). This trait is useful because we pack values in the same manner as
219// references.
220template <typename T>
221struct IsDeviceMemoryValueLike {
222  static constexpr bool value = false;
223};
224
225template <typename U>
226struct IsDeviceMemoryValueLike<DeviceMemory<U> &> {
227  static constexpr bool value = true;
228};
229
230// We need to treat SharedDeviceMemory types differently than other DeviceMemory
231// types (since they maintain no allocations), hence these specializations.
232template <typename U>
233struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> {
234  static constexpr bool value = false;
235};
236
237template <>
238struct IsDeviceMemoryValueLike<DeviceMemoryBase &> {
239  static constexpr bool value = true;
240};
241
242template <typename U>
243struct IsDeviceMemoryValueLike<DeviceMemory<U>> {
244  static constexpr bool value = true;
245};
246
247template <typename U>
248struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> {
249  static constexpr bool value = false;
250};
251
252template <>
253struct IsDeviceMemoryValueLike<DeviceMemoryBase> {
254  static constexpr bool value = true;
255};
256
257template <typename U>
258struct IsSharedDeviceMemory {
259  static constexpr bool value = false;
260};
261
262template <typename U>
263struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> {
264  static constexpr bool value = true;
265};
266
267template <typename U>
268struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
269  static constexpr bool value = true;
270};
271
// Basic data about a kernel argument.
struct KernelArg {
  // True when the argument is a shared-memory argument, which is described only
  // by its size.
  bool is_shared;

  // Address of the argument's value; nullptr for shared-memory arguments.
  const void *address;

  // Size of the argument in bytes. For shared-memory arguments this is the
  // number of shared bytes requested.
  size_t size;
};

// An iterator for traversing all the arguments of a KernelArgsArray.
//
// Arguments are yielded in their original call order: the shared-memory
// entries are interleaved with the addressed entries at the positions recorded
// in the shared-memory index list.
//
// The iterator only borrows the arrays it is given; they must stay alive and
// unmodified for as long as the iterator is used.
class KernelArgIterator {
 public:
  // number_of_argument_addresses: entry count of arg_addresses_data and
  //     arg_sizes_data.
  // number_of_shared_memory_arguments: entry count of shmem_bytes_data and
  //     shmem_indices_data.
  // arg_addresses_data / arg_sizes_data: address and byte size of each
  //     non-shared argument, in call order.
  // shmem_bytes_data: byte size of each shared-memory argument, in call order.
  // shmem_indices_data: position of each shared-memory argument in the overall
  //     argument list, in increasing order.
  KernelArgIterator(int number_of_argument_addresses,
                    int number_of_shared_memory_arguments,
                    const void *const *arg_addresses_data,
                    const size_t *arg_sizes_data,
                    const size_t *shmem_bytes_data,
                    const size_t *shmem_indices_data)
      : arg_index_(0),
        // Widen before adding so the sum is never computed in (overflowable)
        // signed int arithmetic.
        number_of_arguments_(
            static_cast<size_t>(number_of_argument_addresses) +
            static_cast<size_t>(number_of_shared_memory_arguments)),
        arg_address_iter_(arg_addresses_data),
        arg_size_iter_(arg_sizes_data),
        shmem_bytes_iter_(shmem_bytes_data),
        shmem_indices_iter_(shmem_indices_data),
        shmem_indices_end_(shmem_indices_data +
                           number_of_shared_memory_arguments) {}

  // Returns true if another argument is present in the iterator. Does not
  // modify the iterator, so it may be called through a const reference.
  bool has_next() const { return arg_index_ < number_of_arguments_; }

  // Returns the next argument in the iterator and advances past it.
  //
  // Returns a default-constructed (zeroed) KernelArg if there is no next
  // argument; the iterator is left unchanged in that case.
  KernelArg next() {
    KernelArg result = {};
    if (!has_next()) {
      return result;
    }

    // The current position holds a shared-memory argument iff it equals the
    // next not-yet-consumed shared-memory index.
    const bool at_shared_slot = (shmem_indices_iter_ != shmem_indices_end_) &&
                                (arg_index_ == *shmem_indices_iter_);
    if (at_shared_slot) {
      result.is_shared = true;
      result.address = nullptr;
      result.size = *shmem_bytes_iter_;
      ++shmem_indices_iter_;
      ++shmem_bytes_iter_;
    } else {
      result.is_shared = false;
      result.address = *arg_address_iter_;
      result.size = *arg_size_iter_;
      ++arg_address_iter_;
      ++arg_size_iter_;
    }
    ++arg_index_;
    return result;
  }

 private:
  // Position of the next argument to yield, counting shared and non-shared
  // arguments together.
  size_t arg_index_;

  // Total number of arguments (shared plus non-shared).
  size_t number_of_arguments_;

  // Cursors into the non-shared argument address and size arrays.
  const void *const *arg_address_iter_;
  const size_t *arg_size_iter_;

  // Cursors into the shared-memory size and index arrays, and the end of the
  // index array.
  const size_t *shmem_bytes_iter_;
  const size_t *shmem_indices_iter_;
  const size_t *const shmem_indices_end_;
};
335
336// Base class for KernelArgsArray.
337//
338// Supports all the getter methods that do not depend on the compile-time number
339// of arguments template parameter.
340//
341// This class exists as a way to pass kernel arguments to
342// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
343// be templated to accept any KernelArgsArray type, therefore a reference to
344// this base type is passed instead.
345//
346// Performance is not a concern here because each of these methods will be
347// called at most once per kernel launch. Past performance concerns with
348// KernelArgsArray have been in reference to the argument packing routines which
349// are called once per kernel argument. Those packing routines are now handled
350// by the templated KernelArgsArray subclass of this class where they can take
351// advantage of compile-time knowledge of the number of arguments in order to be
352// very efficient.
class KernelArgsArrayBase {
 public:
  // Virtual so that implementations can be destroyed through a pointer or
  // reference to this base type.
  virtual ~KernelArgsArrayBase() = default;

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  virtual size_t number_of_arguments() const = 0;

  // Gets the total number of shared memory bytes added so far.
  virtual uint64 number_of_shared_bytes() const = 0;

  // Gets the list of argument addresses. In the KernelArgsArray implementation
  // this covers only the non-shared-memory arguments, since shared-memory
  // arguments are recorded by size and have no address.
  virtual port::ArraySlice<const void *> argument_addresses() const = 0;

  // Gets an iterator to the arguments in the array.
  virtual KernelArgIterator arg_iterator() const = 0;
};
370
371// A list of arguments for a kernel call.
372//
373// The template parameter kNumArgs is the maximum number of arguments which can
374// be stored in the list.
375//
376// Contains a list of addresses for non-shared-memory arguments and a list of
377// sizes for shared-memory arguments. Since the shared-memory arguments may be
378// interspersed with the non-shared-memory arguments, it also stores a list of
379// the indices at which the shared-memory arguments appeared.
380//
381// For example, if the argument address list contains {a, b, c, d, e}, the
382// shared-memory arguments list contains the sizes of {A, B, C}, and the
383// shared-memory indices list contains {0, 3, 5}, then the original list of
384// arguments was {A, a, b, B, c, C, d, e}.
385//
386// This way of storing the arguments makes CUDA kernel calls efficient because
387// they only require the argument address list and the total number of shared
388// bytes, but it also makes it possible for OpenCL kernel calls because they
389// depend on the location of each shared-memory argument and its size.
390//
391// Note that the code for adding arguments has been identified as a performance
392// hotspot in some real-world applications so this structure has been optimized
393// for the performance of argument adding.
394template <size_t kNumArgs>
395class KernelArgsArray : public KernelArgsArrayBase {
396 public:
397  explicit KernelArgsArray()
398      : total_shared_memory_bytes_(0),
399        number_of_argument_addresses_(0),
400        number_of_shared_memory_arguments_(0) {}
401
402  // Adds an argument to the list.
403  //
404  // Note that the address of the argument is stored, so the input must not go
405  // out of scope before the instance of this class that calls this method does.
406  template <typename T>
407  void add_argument(const T &arg) {
408    argument_addresses_[number_of_argument_addresses_] =
409        static_cast<const void *>(&arg);
410    argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
411    ++number_of_argument_addresses_;
412  }
413
414  // Adds a device memory argument to the list.
415  void add_device_memory_argument(const DeviceMemoryBase &arg) {
416    const void **copy_ptr =
417        &device_memory_opaque_pointers_[number_of_argument_addresses_];
418    *copy_ptr = arg.opaque();
419    argument_addresses_[number_of_argument_addresses_] = copy_ptr;
420    argument_sizes_[number_of_argument_addresses_] = sizeof(void *);
421    ++number_of_argument_addresses_;
422  }
423
424  // Adds a shared memory argument to the list.
425  //
426  // The only significant information about a shared argument is its size, so
427  // that is the only parameter in this function.
428  void add_shared_bytes(size_t number_of_bytes) {
429    shared_memory_indices_[number_of_shared_memory_arguments_] =
430        number_of_argument_addresses_ + number_of_shared_memory_arguments_;
431    shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes;
432    ++number_of_shared_memory_arguments_;
433    total_shared_memory_bytes_ += number_of_bytes;
434  }
435
436  // Gets the number of arguments added so far, including shared memory
437  // arguments.
438  size_t number_of_arguments() const override {
439    return number_of_argument_addresses_ + number_of_shared_memory_arguments_;
440  }
441
442  // Gets the total number of shared memory bytes added so far.
443  uint64 number_of_shared_bytes() const override {
444    return total_shared_memory_bytes_;
445  }
446
447  // Gets the list of argument addresses.
448  port::ArraySlice<const void *> argument_addresses() const override {
449    return port::ArraySlice<const void *>(argument_addresses_.data(),
450                                          number_of_argument_addresses_);
451  }
452
453  // Gets an iterator to the arguments in the array.
454  KernelArgIterator arg_iterator() const override {
455    return KernelArgIterator(
456        number_of_argument_addresses_, number_of_shared_memory_arguments_,
457        argument_addresses_.data(), argument_sizes_.data(),
458        shared_memory_bytes_.data(), shared_memory_indices_.data());
459  }
460
461 private:
462  // A place to store copies of opaque pointers from device memory arguments.
463  std::array<const void *, kNumArgs> device_memory_opaque_pointers_;
464
465  // Addresses for non-shared-memory arguments.
466  std::array<const void *, kNumArgs> argument_addresses_;
467
468  // Sizes for non-shared-memory arguments.
469  std::array<size_t, kNumArgs> argument_sizes_;
470
471  // Size in bytes for each shared memory argument.
472  std::array<size_t, kNumArgs> shared_memory_bytes_;
473
474  // Indices in the arguments array for shared memory arguments.
475  std::array<size_t, kNumArgs> shared_memory_indices_;
476
477  // Total of all shared memory sizes.
478  size_t total_shared_memory_bytes_;
479
480  // Number of significant entries in argument_addresses_ and argument_sizes_.
481  size_t number_of_argument_addresses_;
482
483  // Number of significant entries in shared_memory_bytes_ and
484  // shared_memory_indices_.
485  size_t number_of_shared_memory_arguments_;
486};
487
488// Typed variant of KernelBase, like a typed device function pointer. See the
489// file comment for details and example usage.
490//
491// This class contains template metaprogramming magic to type check the
492// parameters passed to a kernel launch are acceptable, and subsequently pack
493// them into a form which can be used by the StreamExecutorInterface
494// implementation. (i.e.  CUDA and OpenCL both bind void*s with associated
495// sizes as kernel arguments.)
496//
497// Thread-compatible.
template <typename... Params>
class TypedKernel : public KernelBase {
 public:
  // Number of parameters in this kernel's type signature; also the capacity of
  // the KernelArgsArray that PackParams fills.
  static constexpr size_t kNumberOfParameters = sizeof...(Params);

  // Delegates to KernelBase::KernelBase(), see that constructor.
  explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}

  // Test-only constructor that can take a mock KernelInterface implementation.
  // Takes ownership of implementation, it should not be null.
  TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation)
      : KernelBase(parent, implementation) {}

 private:
  // Stream needs access to the specific parameter-packing functionality that
  // the TypedKernel provides for its corresponding type signature (and no other
  // type signatures).
  friend class Stream;

  // This is the main entry point into the magic. Packs the parameters (which
  // must type check against the class template) into the args and sizes
  // arrays.
  //
  // Const refs are taken as parameters on all of the handlers to avoid
  // implicit type promotion of integers.
  //
  // WARNING: as a performance optimization this method may store pointers to
  // some of the input parameters in the kernel args structure, so any params
  // passed into this method must live at least as long as the kernel args
  // structure.
  void PackParams(KernelArgsArray<kNumberOfParameters> *args,
                  Params &... params) const {
    PackOneParam(args, params...);
  }

  // Recursive step of the variadic expansion: packs arg by itself, then packs
  // the remaining parameters in order.
  template <typename T, typename... RestOfParams>
  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args, const T &arg,
                    const RestOfParams &... rest) const {
    PackOneParam(args, arg);
    PackOneParam(args, rest...);
  }

  // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
  // The enable_if<> is for excluding DeviceMemoryBase args, which have a
  // separate implementation below.
  //
  // The address of arg is what gets stored (see
  // KernelArgsArray::add_argument), so arg must outlive args.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
                              !IsDeviceMemoryPointer<T>::value &&
                              !IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    static_assert(!std::is_pointer<T>::value,
                  "cannot pass raw pointer to the device");
    static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
                  "cannot pass device memory as a normal value");
    args->add_argument(arg);
  }

  // DeviceMemoryBase family reference override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
          nullptr) const {
    args->add_device_memory_argument(arg);
  }

  // DeviceMemoryBase family pointer override.
  //
  // arg is dereferenced unconditionally, so it must not be null.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
          nullptr) const {
    DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
    args->add_device_memory_argument(*ptr);
  }

  // Dynamic shared device memory has a size, but no associated allocation on
  // the host; internally, the device will allocate storage.
  //
  // Only arg.size() is recorded in args.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    args->add_shared_bytes(arg.size());
  }

  // Base case for variadic template expansion - nothing to do!
  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args) const {}

  SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
};
591
592// Template metaprogramming helper type that helps us produce better error
593// messages at compile time when the are mismatches between the parameter
594// type list and the argument type list.
595template <typename ParamTuple, typename ArgTuple>
596struct KernelInvocationChecker {
597  // Whether the parameter tuple and argument tuple match in length.
598  static constexpr bool kLengthMatches =
599      std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;
600
601  // The (matching) length of the parameters and arguments type lists.
602  static constexpr int kTupleLength =
603      static_cast<int>(std::tuple_size<ArgTuple>::value);
604
605  // Helper trait to say whether the parameter wants a DeviceMemory-reference
606  // compatible type. This is for inexact type matches, so that it doesn't have
607  // to be precisely a const DeviceMemory<T>&, but can also be a value that
608  // represents the same.
609  template <typename ParamType, typename ArgType>
610  struct IsCompatibleDeviceMemoryRef {
611    static constexpr bool value = false;
612  };
613
614  // See type trait definition above.
615  template <typename U>
616  struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
617    static constexpr bool value = true;
618  };
619
620  // See type trait definition above.
621  template <typename U>
622  struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
623                                     SharedDeviceMemory<U>> {
624    static constexpr bool value = true;
625  };
626
627  // Returns whether ParamT and ArgT are compatible for data parallel kernel
628  // parameter packing without any assert functionality.
629  template <typename ParamT, typename ArgT>
630  static constexpr bool CompatibleNoAssert() {
631    return std::is_same<typename std::remove_const<ParamT>::type,
632                        ArgT>::value ||
633           IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
634  }
635
636  // Checks whether ParamT and ArgT are compatible for data parallel kernel
637  // parameter packing. kArgumentNumber is unused, it just for error display.
638  //
639  // NOTE: if you encounter an error here, you can see the mismatch by looking
640  // at the end of the last error message, which will be of the form:
641  //
642  //    ...::Compatible<const perftools::gputools::DeviceMemory<OneThing> &,
643  //                    perftools::gputools::DeviceMemory<AnotherThing>, true,
644  //                    0>'
645  //    requested here
646  //
647  // This means that the 0th argument you passed to the kernel invocation should
648  // have been DeviceMemory<OneThing> but was observed to be
649  // DeviceMemory<AnotherThing>.
650  template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
651            int kArgumentNumber>
652  static constexpr bool Compatible() {
653    static_assert(
654        kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
655        "parameter type (LHS) is not compatible with argument type (RHS)");
656    return CompatibleNoAssert<ParamT, ArgT>();
657  }
658
659  // Checks the parameter/argument match at kArgumentNumber for an out of bounds
660  // argument number.
661  //
662  // This is the base case: we've run out of argument to check, so we're all
663  // good.
664  template <int kArgumentNumber, bool kShouldStaticAssert>
665  static constexpr bool CheckParam(
666      typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
667    return true;
668  }
669
670  // Checks the parameter/argument match at kArgumentNumber.
671  // kShouldStaticAssert determines whether to assert out on a mismatch, or just
672  // yield the constexpr boolean value.
673  template <int kArgumentNumber, bool kShouldStaticAssert>
674  static constexpr bool CheckParam(
675      typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
676    typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
677        ParamT;
678    typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
679    return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
680           CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
681  }
682
683  // Checks the parameters/arguments for match, but doesn't static assert out.
684  // This is useful for testing/inspecting whether a set of parameters match in
685  // things like tests.
686  static constexpr bool CheckAllNoStaticAssert() {
687    return kLengthMatches && CheckParam<kTupleLength - 1, false>();
688  }
689
690  // Checks the parameters and static asserts out with a helpful error message
691  // (and useful template parameters in the instantiation stack) if there is an
692  // error.
693  static constexpr bool CheckAllStaticAssert() {
694    static_assert(kLengthMatches,
695                  "argument length mismatched against typed kernel parameters");
696    return kLengthMatches && CheckParam<kTupleLength - 1, true>();
697  }
698};
699
700// This is a convenience type for checking whether a typed kernel matches
701// against a type list.
702template <typename KernelT, typename... Params>
703struct KernelParamsOk {
704  static constexpr bool kResult = false;
705};
706
707// See above.
708template <typename... Params, typename... Args>
709struct KernelParamsOk<TypedKernel<Params...>, Args...> {
710  static constexpr bool kResult = KernelInvocationChecker<
711      std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
712};
713
714}  // namespace gputools
715}  // namespace perftools
716
717#endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
718