kernel.h revision f41959ccb2d9d4c722fe8fc3351401d53bcf4900
1// Suite of datatypes to represent data-parallel kernel objects (code entities). 2// Kernel is the untyped variant, whereas TypedKernel takes a type signature 3// to do some template-based helper generation and give compile-time type 4// checking for kernel launch parameters. 5// 6// Users typically don't see KernelBase, they see typed kernels, analogous to a 7// typed function pointer. TypedKernels express their argument types via 8// template parameters like so: 9// 10// TypedKernel<DeviceMemory<int>*, int> 11// 12// Which expresses a data parallel kernel signature for: 13// 14// void(int*, int); 15// 16// And for a const memory region: 17// 18// TypedKernel<const DeviceMemory<int>&, int> 19// 20// Corresponds to a data parallel kernel signature for: 21// 22// void(const int*, int) 23// 24// Note that kernels always have a void return type, so results typically must 25// be memcpy'ied from device memory to the host. 26// 27// Also note that a scalar integer residing in device memory and an array of 28// integers residing in device memory have the same signature: DeviceMemory<T>. 29// However, in the future, checks may be added for additional safety that arrays 30// of minimum sizes are passed when those minimum sizes are contractually 31// expected by the kernel. 32// 33// For user-defined types whose definitions are appropriately shared between the 34// host code doing the launching and the kernel code being launched, the user 35// defined types are similarly permitted to be expressed as residing in device 36// memory: 37// 38// TypedKernel<DeviceMemory<MyUserDefinedStructure>> 39// 40// And, when the alignment and padding are agreed upon, POD types will also be 41// able to be passed by value; for example, it is a common idiom to specify a 42// bunch of options simultaneously with a structure: 43// 44// TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>> 45// 46// Which corresponds to a data parallel kernel signature like: 47// 48// void(MyOptionsStructurePassedByValue value, float *result); 49// 50// Users typically won't need to type out the TypedKernel signature in full, it 51// will be typedef'd by automatically generated code; for example, see 52// perftools::gputools::executor_sample::VecReduceAddKernel. 53 54#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ 55#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ 56 57#include <memory> 58#include <tuple> 59#include <type_traits> 60#include <vector> 61 62#include "tensorflow/stream_executor/device_memory.h" 63#include "tensorflow/stream_executor/kernel_cache_config.h" 64#include "tensorflow/stream_executor/lib/stringpiece.h" 65#include "tensorflow/stream_executor/platform/port.h" 66#include "tensorflow/stream_executor/lib/inlined_vector.h" 67 68namespace perftools { 69namespace gputools { 70 71class DeviceMemoryBase; 72template <typename ElemT> 73class DeviceMemory; 74class StreamExecutor; 75 76namespace internal { 77class KernelInterface; 78} // namespace internal 79 80// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as 81// registers allocated, shared memory used, etc. 82// Not all platforms support reporting of all information, so each accessor 83// returns false if the associated field is not populated in the underlying 84// platform. 85class KernelMetadata { 86 public: 87 KernelMetadata() 88 : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {} 89 90 // Returns the number of registers used per thread executing this kernel. 91 bool registers_per_thread(int *registers_per_thread) const; 92 93 // Sets the number of registers used per thread executing this kernel. 94 void set_registers_per_thread(int registers_per_thread); 95 96 // Returns the amount of [static] shared memory used per block executing this 97 // kernel. Note that dynamic shared memory allocations are not (and can not) 98 // be reported here (since they're not specified until kernel launch time). 99 bool shared_memory_bytes(int *shared_memory_bytes) const; 100 101 // Sets the amount of [static] shared memory used per block executing this 102 // kernel. 103 void set_shared_memory_bytes(int shared_memory_bytes); 104 105 private: 106 // Holds the value returned by registers_per_thread above. 107 bool has_registers_per_thread_; 108 int registers_per_thread_; 109 110 // Holds the value returned by shared_memory_bytes above. 111 bool has_shared_memory_bytes_; 112 int64 shared_memory_bytes_; 113}; 114 115// A data-parallel kernel (code entity) for launching via the StreamExecutor, 116// analogous to a void* device function pointer. See TypedKernel for the typed 117// variant. 118// 119// Thread-compatible. 120class KernelBase { 121 public: 122 // Constructs an "empty" (not-yet-loaded) kernel instance. 123 // 124 // parent is the StreamExecutor that will be responsible for loading the 125 // implementation of this kernel. It must not be null. 126 explicit KernelBase(StreamExecutor *parent); 127 128 // Test-only constructor that can take a mock KernelInterface implementation. 129 // Takes ownership of implementation, it should not be null. 130 KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation); 131 132 // Releases resources associated with the kernel instance (i.e. 133 // platform-specific implementation). 134 ~KernelBase(); 135 136 // Returns the number of parameters that this kernel accepts. (Arity refers to 137 // nullary, unary, ...). 138 unsigned Arity() const; 139 140 // Returns the StreamExecutor that represents the platform this kernel 141 // executes upon. 142 StreamExecutor *parent() const { return parent_; } 143 144 // Returns a const pointer to the (opaque) platform-dependent implementation. 145 const internal::KernelInterface *implementation() const { 146 return implementation_.get(); 147 } 148 149 // Returns a non-const pointer to the (opaque) platform-dependent 150 // implementation. 151 internal::KernelInterface *implementation() { return implementation_.get(); } 152 153 void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; } 154 155 const KernelMetadata &metadata() const { return metadata_; } 156 157 // Sets the preferred cache configuration for a kernel. This is just a 158 // suggestion to the runtime, and may not be honored during execution. 159 void SetPreferredCacheConfig(KernelCacheConfig config); 160 161 // Gets the preferred cache configuration for a kernel. 162 KernelCacheConfig GetPreferredCacheConfig() const; 163 164 void set_name(port::StringPiece name); 165 const string &name() const { return name_; } 166 const string &demangled_name() const { return demangled_name_; } 167 168 private: 169 // Implementation delegated to for platform-specific functionality. 170 std::unique_ptr<internal::KernelInterface> implementation_; 171 172 // The StreamExecutor that loads this kernel object. 173 StreamExecutor *parent_; 174 175 string name_; 176 string demangled_name_; 177 178 KernelMetadata metadata_; 179 180 SE_DISALLOW_COPY_AND_ASSIGN(KernelBase); 181}; 182 183// Whether T is a DeviceMemory-family pointer. 184template <typename T> 185struct IsDeviceMemoryPointer { 186 static constexpr bool value = false; 187}; 188 189template <typename U> 190struct IsDeviceMemoryPointer<DeviceMemory<U> *> { 191 static constexpr bool value = true; 192}; 193 194template <> 195struct IsDeviceMemoryPointer<DeviceMemoryBase *> { 196 static constexpr bool value = true; 197}; 198 199// Whether T is a DeviceMemory-family value-like thing (which includes a 200// reference). This trait is useful because we pack values in the same manner as 201// references. 202template <typename T> 203struct IsDeviceMemoryValueLike { 204 static constexpr bool value = false; 205}; 206 207template <typename U> 208struct IsDeviceMemoryValueLike<DeviceMemory<U> &> { 209 static constexpr bool value = true; 210}; 211 212// We need to treat SharedDeviceMemory types differently than other DeviceMemory 213// types (since they maintain no allocations), hence these specializations. 214template <typename U> 215struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> { 216 static constexpr bool value = false; 217}; 218 219template <> 220struct IsDeviceMemoryValueLike<DeviceMemoryBase &> { 221 static constexpr bool value = true; 222}; 223 224template <typename U> 225struct IsDeviceMemoryValueLike<DeviceMemory<U>> { 226 static constexpr bool value = true; 227}; 228 229template <typename U> 230struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> { 231 static constexpr bool value = false; 232}; 233 234template <> 235struct IsDeviceMemoryValueLike<DeviceMemoryBase> { 236 static constexpr bool value = true; 237}; 238 239template <typename U> 240struct IsSharedDeviceMemory { 241 static constexpr bool value = false; 242}; 243 244template <typename U> 245struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> { 246 static constexpr bool value = true; 247}; 248 249template <typename U> 250struct IsSharedDeviceMemory<SharedDeviceMemory<U>> { 251 static constexpr bool value = true; 252}; 253 254// KernelArg encapsulates the information necessary for a back-end executor to 255// configure a kernel to launch using the given argument. 256struct KernelArg { 257 // Indicates the type of an argument: normal, to be passed to the kernel 258 // in the standard manner, or shared memory, which has distinct 259 // rules for specification per backend. 260 enum Type { 261 kNormal, 262 kSharedMemory, 263 } type; 264 265 // The data to pass to the kernel - either a pointer to device memory, or the 266 // argument value. compact_array is used to prevent smaller args (ex. u8, u64) 267 // from requiring heap allocation. 268 port::InlinedVector<uint8, 4> data; 269 270 // The size of this argument in bytes. 271 uint64 bytes; 272}; 273 274// Typed variant of KernelBase, like a typed device function pointer. See the 275// file comment for details and example usage. 276// 277// This class contains template metaprogramming magic to type check the 278// parameters passed to a kernel launch are acceptable, and subsequently pack 279// them into a form which can be used by the StreamExecutorInterface 280// implementation. (i.e. CUDA and OpenCL both bind void*s with associated 281// sizes as kernel arguments.) 282// 283// Thread-compatible. 284template <typename... Params> 285class TypedKernel : public KernelBase { 286 public: 287 // Delegates to KernelBase::KernelBase(), see that constructor. 288 explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {} 289 290 // Test-only constructor that can take a mock KernelInterface implementation. 291 // Takes ownership of implementation, it should not be null. 292 TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation) 293 : KernelBase(parent, implementation) {} 294 295 private: 296 // Stream needs access to the specific parameter-packing functionality that 297 // the TypedKernel provides for its corresponding type signature (and no other 298 // type signatures). 299 friend class Stream; 300 301 // This is the main entry point into the magic. Packs the parameters (which 302 // must type check against the class template) into the args and sizes 303 // arrays. 304 // 305 // Const refs are taken as parameters on all of the handlers to avoid 306 // implicit type promotion of integers. 307 void PackParams(std::vector<KernelArg> *args, Params... params) const { 308 PackOneParam(args, params...); 309 } 310 311 template <typename T, typename... RestOfParams> 312 void PackOneParam(std::vector<KernelArg> *args, const T &arg, 313 const RestOfParams... rest) const { 314 PackOneParam(args, arg); 315 PackOneParam(args, rest...); 316 } 317 318 // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array. 319 // The enable_if<> is for excluding DeviceMemoryBase args, which have a 320 // separate implementation below. 321 template <typename T> 322 void PackOneParam( 323 std::vector<KernelArg> *args, const T &arg, 324 typename std::enable_if<!IsDeviceMemoryValueLike<T>::value && 325 !IsDeviceMemoryPointer<T>::value && 326 !IsSharedDeviceMemory<T>::value>::type * = 327 nullptr) const { 328 static_assert(!std::is_pointer<T>::value, 329 "cannot pass raw pointer to the device"); 330 static_assert(!std::is_convertible<T, DeviceMemoryBase>::value, 331 "cannot pass device memory as a normal value"); 332 const uint8 *arg_ptr = reinterpret_cast<const uint8 *>(&arg); 333 args->emplace_back(KernelArg{ 334 KernelArg::kNormal, 335 port::InlinedVector<uint8, 4>{arg_ptr, arg_ptr + sizeof(arg)}, sizeof(arg)}); 336 } 337 338 // DeviceMemoryBase family reference override. 339 template <typename T> 340 void PackOneParam( 341 std::vector<KernelArg> *args, const T &arg, 342 typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * = 343 nullptr) const { 344 args->emplace_back(parent()->DeviceMemoryToKernelArg(arg)); 345 } 346 347 // DeviceMemoryBase family pointer override. 348 template <typename T> 349 void PackOneParam( 350 std::vector<KernelArg> *args, T arg, 351 typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * = 352 nullptr) const { 353 DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg); 354 args->emplace_back(parent()->DeviceMemoryToKernelArg(*ptr)); 355 } 356 357 // Dynamic shared device memory has a size, but no associated allocation on 358 // the host; internally, the device will allocate storage. 359 template <typename T> 360 void PackOneParam( 361 std::vector<KernelArg> *args, T arg, 362 typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * = 363 nullptr) const { 364 args->emplace_back(KernelArg{KernelArg::kSharedMemory, 365 port::InlinedVector<uint8, 4>(), arg.size()}); 366 } 367 368 // Base case for variadic template expansion - nothing to do! 369 void PackOneParam(std::vector<KernelArg> *args) const {} 370 371 SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel); 372}; 373 374// Template metaprogramming helper type that helps us produce better error 375// messages at compile time when the are mismatches between the parameter 376// type list and the argument type list. 377template <typename ParamTuple, typename ArgTuple> 378struct KernelInvocationChecker { 379 // Whether the parameter tuple and argument tuple match in length. 380 static constexpr bool kLengthMatches = 381 std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value; 382 383 // The (matching) length of the parameters and arguments type lists. 384 static constexpr int kTupleLength = 385 static_cast<int>(std::tuple_size<ArgTuple>::value); 386 387 // Helper trait to say whether the parameter wants a DeviceMemory-reference 388 // compatible type. This is for inexact type matches, so that it doesn't have 389 // to be precisely a const DeviceMemory<T>&, but can also be a value that 390 // represents the same. 391 template <typename ParamType, typename ArgType> 392 struct IsCompatibleDeviceMemoryRef { 393 static constexpr bool value = false; 394 }; 395 396 // See type trait definition above. 397 template <typename U> 398 struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> { 399 static constexpr bool value = true; 400 }; 401 402 // See type trait definition above. 403 template <typename U> 404 struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &, 405 SharedDeviceMemory<U>> { 406 static constexpr bool value = true; 407 }; 408 409 // Returns whether ParamT and ArgT are compatible for data parallel kernel 410 // parameter packing without any assert functionality. 411 template <typename ParamT, typename ArgT> 412 static constexpr bool CompatibleNoAssert() { 413 return std::is_same<typename std::remove_const<ParamT>::type, 414 ArgT>::value || 415 IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value; 416 } 417 418 // Checks whether ParamT and ArgT are compatible for data parallel kernel 419 // parameter packing. kArgumentNumber is unused, it just for error display. 420 // 421 // NOTE: if you encounter an error here, you can see the mismatch by looking 422 // at the end of the last error message, which will be of the form: 423 // 424 // ...::Compatible<const perftools::gputools::DeviceMemory<OneThing> &, 425 // perftools::gputools::DeviceMemory<AnotherThing>, true, 426 // 0>' 427 // requested here 428 // 429 // This means that the 0th argument you passed to the kernel invocation should 430 // have been DeviceMemory<OneThing> but was observed to be 431 // DeviceMemory<AnotherThing>. 432 template <typename ParamT, typename ArgT, bool kShouldStaticAssert, 433 int kArgumentNumber> 434 static constexpr bool Compatible() { 435 static_assert( 436 kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true, 437 "parameter type (LHS) is not compatible with argument type (RHS)"); 438 return CompatibleNoAssert<ParamT, ArgT>(); 439 } 440 441 // Checks the parameter/argument match at kArgumentNumber for an out of bounds 442 // argument number. 443 // 444 // This is the base case: we've run out of argument to check, so we're all 445 // good. 446 template <int kArgumentNumber, bool kShouldStaticAssert> 447 static constexpr bool CheckParam( 448 typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) { 449 return true; 450 } 451 452 // Checks the parameter/argument match at kArgumentNumber. 453 // kShouldStaticAssert determines whether to assert out on a mismatch, or just 454 // yield the constexpr boolean value. 455 template <int kArgumentNumber, bool kShouldStaticAssert> 456 static constexpr bool CheckParam( 457 typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) { 458 typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type 459 ParamT; 460 typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT; 461 return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() && 462 CheckParam<kArgumentNumber - 1, kShouldStaticAssert>(); 463 } 464 465 // Checks the parameters/arguments for match, but doesn't static assert out. 466 // This is useful for testing/inspecting whether a set of parameters match in 467 // things like tests. 468 static constexpr bool CheckAllNoStaticAssert() { 469 return kLengthMatches && CheckParam<kTupleLength - 1, false>(); 470 } 471 472 // Checks the parameters and static asserts out with a helpful error message 473 // (and useful template parameters in the instantiation stack) if there is an 474 // error. 475 static constexpr bool CheckAllStaticAssert() { 476 static_assert(kLengthMatches, 477 "argument length mismatched against typed kernel parameters"); 478 return kLengthMatches && CheckParam<kTupleLength - 1, true>(); 479 } 480}; 481 482// This is a convenience type for checking whether a typed kernel matches 483// against a type list. 484template <typename KernelT, typename... Params> 485struct KernelParamsOk { 486 static constexpr bool kResult = false; 487}; 488 489// See above. 490template <typename... Params, typename... Args> 491struct KernelParamsOk<TypedKernel<Params...>, Args...> { 492 static constexpr bool kResult = KernelInvocationChecker< 493 std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert(); 494}; 495 496} // namespace gputools 497} // namespace perftools 498 499#endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ 500