kernel.h revision f41959ccb2d9d4c722fe8fc3351401d53bcf4900
1// Suite of datatypes to represent data-parallel kernel objects (code entities).
2// Kernel is the untyped variant, whereas TypedKernel takes a type signature
3// to do some template-based helper generation and give compile-time type
4// checking for kernel launch parameters.
5//
6// Users typically don't see KernelBase, they see typed kernels, analogous to a
7// typed function pointer. TypedKernels express their argument types via
8// template parameters like so:
9//
10//  TypedKernel<DeviceMemory<int>*, int>
11//
12// Which expresses a data parallel kernel signature for:
13//
14//  void(int*, int);
15//
16// And for a const memory region:
17//
18//  TypedKernel<const DeviceMemory<int>&, int>
19//
20// Corresponds to a data parallel kernel signature for:
21//
22//  void(const int*, int)
23//
24// Note that kernels always have a void return type, so results typically must
25// be memcpy'ied from device memory to the host.
26//
27// Also note that a scalar integer residing in device memory and an array of
28// integers residing in device memory have the same signature: DeviceMemory<T>.
29// However, in the future, checks may be added for additional safety that arrays
30// of minimum sizes are passed when those minimum sizes are contractually
31// expected by the kernel.
32//
33// For user-defined types whose definitions are appropriately shared between the
34// host code doing the launching and the kernel code being launched, the user
35// defined types are similarly permitted to be expressed as residing in device
36// memory:
37//
38//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
39//
40// And, when the alignment and padding are agreed upon, POD types will also be
41// able to be passed by value; for example, it is a common idiom to specify a
42// bunch of options simultaneously with a structure:
43//
44//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
45//
46// Which corresponds to a data parallel kernel signature like:
47//
48//  void(MyOptionsStructurePassedByValue value, float *result);
49//
50// Users typically won't need to type out the TypedKernel signature in full, it
51// will be typedef'd by automatically generated code; for example, see
52// perftools::gputools::executor_sample::VecReduceAddKernel.
53
54#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
55#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
56
57#include <memory>
58#include <tuple>
59#include <type_traits>
60#include <vector>
61
62#include "tensorflow/stream_executor/device_memory.h"
63#include "tensorflow/stream_executor/kernel_cache_config.h"
64#include "tensorflow/stream_executor/lib/stringpiece.h"
65#include "tensorflow/stream_executor/platform/port.h"
66#include "tensorflow/stream_executor/lib/inlined_vector.h"
67
68namespace perftools {
69namespace gputools {
70
// Forward declarations for types referenced by the declarations below.
class DeviceMemoryBase;
template <typename ElemT> class DeviceMemory;
class StreamExecutor;

namespace internal {
// Platform-dependent kernel implementation type; only named here.
class KernelInterface;
}  // namespace internal
79
80// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
81// registers allocated, shared memory used, etc.
82// Not all platforms support reporting of all information, so each accessor
83// returns false if the associated field is not populated in the underlying
84// platform.
85class KernelMetadata {
86 public:
87  KernelMetadata()
88      : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}
89
90  // Returns the number of registers used per thread executing this kernel.
91  bool registers_per_thread(int *registers_per_thread) const;
92
93  // Sets the number of registers used per thread executing this kernel.
94  void set_registers_per_thread(int registers_per_thread);
95
96  // Returns the amount of [static] shared memory used per block executing this
97  // kernel. Note that dynamic shared memory allocations are not (and can not)
98  // be reported here (since they're not specified until kernel launch time).
99  bool shared_memory_bytes(int *shared_memory_bytes) const;
100
101  // Sets the amount of [static] shared memory used per block executing this
102  // kernel.
103  void set_shared_memory_bytes(int shared_memory_bytes);
104
105 private:
106  // Holds the value returned by registers_per_thread above.
107  bool has_registers_per_thread_;
108  int registers_per_thread_;
109
110  // Holds the value returned by shared_memory_bytes above.
111  bool has_shared_memory_bytes_;
112  int64 shared_memory_bytes_;
113};
114
115// A data-parallel kernel (code entity) for launching via the StreamExecutor,
116// analogous to a void* device function pointer. See TypedKernel for the typed
117// variant.
118//
119// Thread-compatible.
120class KernelBase {
121 public:
122  // Constructs an "empty" (not-yet-loaded) kernel instance.
123  //
124  // parent is the StreamExecutor that will be responsible for loading the
125  // implementation of this kernel. It must not be null.
126  explicit KernelBase(StreamExecutor *parent);
127
128  // Test-only constructor that can take a mock KernelInterface implementation.
129  // Takes ownership of implementation, it should not be null.
130  KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);
131
132  // Releases resources associated with the kernel instance (i.e.
133  // platform-specific implementation).
134  ~KernelBase();
135
136  // Returns the number of parameters that this kernel accepts. (Arity refers to
137  // nullary, unary, ...).
138  unsigned Arity() const;
139
140  // Returns the StreamExecutor that represents the platform this kernel
141  // executes upon.
142  StreamExecutor *parent() const { return parent_; }
143
144  // Returns a const pointer to the (opaque) platform-dependent implementation.
145  const internal::KernelInterface *implementation() const {
146    return implementation_.get();
147  }
148
149  // Returns a non-const pointer to the (opaque) platform-dependent
150  // implementation.
151  internal::KernelInterface *implementation() { return implementation_.get(); }
152
153  void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }
154
155  const KernelMetadata &metadata() const { return metadata_; }
156
157  // Sets the preferred cache configuration for a kernel. This is just a
158  // suggestion to the runtime, and may not be honored during execution.
159  void SetPreferredCacheConfig(KernelCacheConfig config);
160
161  // Gets the preferred cache configuration for a kernel.
162  KernelCacheConfig GetPreferredCacheConfig() const;
163
164  void set_name(port::StringPiece name);
165  const string &name() const { return name_; }
166  const string &demangled_name() const { return demangled_name_; }
167
168 private:
169  // Implementation delegated to for platform-specific functionality.
170  std::unique_ptr<internal::KernelInterface> implementation_;
171
172  // The StreamExecutor that loads this kernel object.
173  StreamExecutor *parent_;
174
175  string name_;
176  string demangled_name_;
177
178  KernelMetadata metadata_;
179
180  SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
181};
182
183// Whether T is a DeviceMemory-family pointer.
184template <typename T>
185struct IsDeviceMemoryPointer {
186  static constexpr bool value = false;
187};
188
189template <typename U>
190struct IsDeviceMemoryPointer<DeviceMemory<U> *> {
191  static constexpr bool value = true;
192};
193
194template <>
195struct IsDeviceMemoryPointer<DeviceMemoryBase *> {
196  static constexpr bool value = true;
197};
198
199// Whether T is a DeviceMemory-family value-like thing (which includes a
200// reference). This trait is useful because we pack values in the same manner as
201// references.
202template <typename T>
203struct IsDeviceMemoryValueLike {
204  static constexpr bool value = false;
205};
206
207template <typename U>
208struct IsDeviceMemoryValueLike<DeviceMemory<U> &> {
209  static constexpr bool value = true;
210};
211
212// We need to treat SharedDeviceMemory types differently than other DeviceMemory
213// types (since they maintain no allocations), hence these specializations.
214template <typename U>
215struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> {
216  static constexpr bool value = false;
217};
218
219template <>
220struct IsDeviceMemoryValueLike<DeviceMemoryBase &> {
221  static constexpr bool value = true;
222};
223
224template <typename U>
225struct IsDeviceMemoryValueLike<DeviceMemory<U>> {
226  static constexpr bool value = true;
227};
228
229template <typename U>
230struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> {
231  static constexpr bool value = false;
232};
233
234template <>
235struct IsDeviceMemoryValueLike<DeviceMemoryBase> {
236  static constexpr bool value = true;
237};
238
239template <typename U>
240struct IsSharedDeviceMemory {
241  static constexpr bool value = false;
242};
243
244template <typename U>
245struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> {
246  static constexpr bool value = true;
247};
248
249template <typename U>
250struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
251  static constexpr bool value = true;
252};
253
254// KernelArg encapsulates the information necessary for a back-end executor to
255// configure a kernel to launch using the given argument.
256struct KernelArg {
257  // Indicates the type of an argument: normal, to be passed to the kernel
258  // in the standard manner, or shared memory, which has distinct
259  // rules for specification per backend.
260  enum Type {
261    kNormal,
262    kSharedMemory,
263  } type;
264
265  // The data to pass to the kernel - either a pointer to device memory, or the
266  // argument value. compact_array is used to prevent smaller args (ex. u8, u64)
267  // from requiring heap allocation.
268  port::InlinedVector<uint8, 4> data;
269
270  // The size of this argument in bytes.
271  uint64 bytes;
272};
273
274// Typed variant of KernelBase, like a typed device function pointer. See the
275// file comment for details and example usage.
276//
277// This class contains template metaprogramming magic to type check the
278// parameters passed to a kernel launch are acceptable, and subsequently pack
279// them into a form which can be used by the StreamExecutorInterface
280// implementation. (i.e.  CUDA and OpenCL both bind void*s with associated
281// sizes as kernel arguments.)
282//
283// Thread-compatible.
284template <typename... Params>
285class TypedKernel : public KernelBase {
286 public:
287  // Delegates to KernelBase::KernelBase(), see that constructor.
288  explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}
289
290  // Test-only constructor that can take a mock KernelInterface implementation.
291  // Takes ownership of implementation, it should not be null.
292  TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation)
293      : KernelBase(parent, implementation) {}
294
295 private:
296  // Stream needs access to the specific parameter-packing functionality that
297  // the TypedKernel provides for its corresponding type signature (and no other
298  // type signatures).
299  friend class Stream;
300
301  // This is the main entry point into the magic. Packs the parameters (which
302  // must type check against the class template) into the args and sizes
303  // arrays.
304  //
305  // Const refs are taken as parameters on all of the handlers to avoid
306  // implicit type promotion of integers.
307  void PackParams(std::vector<KernelArg> *args, Params... params) const {
308    PackOneParam(args, params...);
309  }
310
311  template <typename T, typename... RestOfParams>
312  void PackOneParam(std::vector<KernelArg> *args, const T &arg,
313                    const RestOfParams... rest) const {
314    PackOneParam(args, arg);
315    PackOneParam(args, rest...);
316  }
317
318  // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
319  // The enable_if<> is for excluding DeviceMemoryBase args, which have a
320  // separate implementation below.
321  template <typename T>
322  void PackOneParam(
323      std::vector<KernelArg> *args, const T &arg,
324      typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
325                              !IsDeviceMemoryPointer<T>::value &&
326                              !IsSharedDeviceMemory<T>::value>::type * =
327          nullptr) const {
328    static_assert(!std::is_pointer<T>::value,
329                  "cannot pass raw pointer to the device");
330    static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
331                  "cannot pass device memory as a normal value");
332    const uint8 *arg_ptr = reinterpret_cast<const uint8 *>(&arg);
333    args->emplace_back(KernelArg{
334        KernelArg::kNormal,
335        port::InlinedVector<uint8, 4>{arg_ptr, arg_ptr + sizeof(arg)}, sizeof(arg)});
336  }
337
338  // DeviceMemoryBase family reference override.
339  template <typename T>
340  void PackOneParam(
341      std::vector<KernelArg> *args, const T &arg,
342      typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
343          nullptr) const {
344    args->emplace_back(parent()->DeviceMemoryToKernelArg(arg));
345  }
346
347  // DeviceMemoryBase family pointer override.
348  template <typename T>
349  void PackOneParam(
350      std::vector<KernelArg> *args, T arg,
351      typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
352          nullptr) const {
353    DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
354    args->emplace_back(parent()->DeviceMemoryToKernelArg(*ptr));
355  }
356
357  // Dynamic shared device memory has a size, but no associated allocation on
358  // the host; internally, the device will allocate storage.
359  template <typename T>
360  void PackOneParam(
361      std::vector<KernelArg> *args, T arg,
362      typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
363          nullptr) const {
364    args->emplace_back(KernelArg{KernelArg::kSharedMemory,
365                                 port::InlinedVector<uint8, 4>(), arg.size()});
366  }
367
368  // Base case for variadic template expansion - nothing to do!
369  void PackOneParam(std::vector<KernelArg> *args) const {}
370
371  SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
372};
373
374// Template metaprogramming helper type that helps us produce better error
375// messages at compile time when the are mismatches between the parameter
376// type list and the argument type list.
377template <typename ParamTuple, typename ArgTuple>
378struct KernelInvocationChecker {
379  // Whether the parameter tuple and argument tuple match in length.
380  static constexpr bool kLengthMatches =
381      std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;
382
383  // The (matching) length of the parameters and arguments type lists.
384  static constexpr int kTupleLength =
385      static_cast<int>(std::tuple_size<ArgTuple>::value);
386
387  // Helper trait to say whether the parameter wants a DeviceMemory-reference
388  // compatible type. This is for inexact type matches, so that it doesn't have
389  // to be precisely a const DeviceMemory<T>&, but can also be a value that
390  // represents the same.
391  template <typename ParamType, typename ArgType>
392  struct IsCompatibleDeviceMemoryRef {
393    static constexpr bool value = false;
394  };
395
396  // See type trait definition above.
397  template <typename U>
398  struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
399    static constexpr bool value = true;
400  };
401
402  // See type trait definition above.
403  template <typename U>
404  struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
405                                     SharedDeviceMemory<U>> {
406    static constexpr bool value = true;
407  };
408
409  // Returns whether ParamT and ArgT are compatible for data parallel kernel
410  // parameter packing without any assert functionality.
411  template <typename ParamT, typename ArgT>
412  static constexpr bool CompatibleNoAssert() {
413    return std::is_same<typename std::remove_const<ParamT>::type,
414                        ArgT>::value ||
415           IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
416  }
417
418  // Checks whether ParamT and ArgT are compatible for data parallel kernel
419  // parameter packing. kArgumentNumber is unused, it just for error display.
420  //
421  // NOTE: if you encounter an error here, you can see the mismatch by looking
422  // at the end of the last error message, which will be of the form:
423  //
424  //    ...::Compatible<const perftools::gputools::DeviceMemory<OneThing> &,
425  //                    perftools::gputools::DeviceMemory<AnotherThing>, true,
426  //                    0>'
427  //    requested here
428  //
429  // This means that the 0th argument you passed to the kernel invocation should
430  // have been DeviceMemory<OneThing> but was observed to be
431  // DeviceMemory<AnotherThing>.
432  template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
433            int kArgumentNumber>
434  static constexpr bool Compatible() {
435    static_assert(
436        kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
437        "parameter type (LHS) is not compatible with argument type (RHS)");
438    return CompatibleNoAssert<ParamT, ArgT>();
439  }
440
441  // Checks the parameter/argument match at kArgumentNumber for an out of bounds
442  // argument number.
443  //
444  // This is the base case: we've run out of argument to check, so we're all
445  // good.
446  template <int kArgumentNumber, bool kShouldStaticAssert>
447  static constexpr bool CheckParam(
448      typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
449    return true;
450  }
451
452  // Checks the parameter/argument match at kArgumentNumber.
453  // kShouldStaticAssert determines whether to assert out on a mismatch, or just
454  // yield the constexpr boolean value.
455  template <int kArgumentNumber, bool kShouldStaticAssert>
456  static constexpr bool CheckParam(
457      typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
458    typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
459        ParamT;
460    typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
461    return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
462           CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
463  }
464
465  // Checks the parameters/arguments for match, but doesn't static assert out.
466  // This is useful for testing/inspecting whether a set of parameters match in
467  // things like tests.
468  static constexpr bool CheckAllNoStaticAssert() {
469    return kLengthMatches && CheckParam<kTupleLength - 1, false>();
470  }
471
472  // Checks the parameters and static asserts out with a helpful error message
473  // (and useful template parameters in the instantiation stack) if there is an
474  // error.
475  static constexpr bool CheckAllStaticAssert() {
476    static_assert(kLengthMatches,
477                  "argument length mismatched against typed kernel parameters");
478    return kLengthMatches && CheckParam<kTupleLength - 1, true>();
479  }
480};
481
482// This is a convenience type for checking whether a typed kernel matches
483// against a type list.
484template <typename KernelT, typename... Params>
485struct KernelParamsOk {
486  static constexpr bool kResult = false;
487};
488
489// See above.
490template <typename... Params, typename... Args>
491struct KernelParamsOk<TypedKernel<Params...>, Args...> {
492  static constexpr bool kResult = KernelInvocationChecker<
493      std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
494};
495
496}  // namespace gputools
497}  // namespace perftools
498
499#endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
500