cuda_fft.h revision 4c0adf2c26345dd63e2b883317f7efb464862532
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA-specific support for FFT functionality -- this wraps the cuFFT library
// capabilities, and is only included into CUDA implementation code -- it will
// not introduce cuda headers into other code.
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ 21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 23a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang#include "cuda/include/cufft.h" 24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/fft.h" 25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/platform/port.h" 26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/plugin_registry.h" 27a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang#include "tensorflow/stream_executor/scratch_allocator.h" 28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace perftools { 30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace gputools { 31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass Stream; 33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace cuda { 35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass CUDAExecutor; 37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Opaque and unique indentifier for the cuFFT plugin. 39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurextern const PluginId kCuFftPlugin; 40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 41a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// CUDAFftPlan uses deferred initialization. 
Only a single call of 42a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// Initialize() is allowed to properly create cufft plan and set member 43a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// variable is_initialized_ to true. Newly added interface that uses member 44a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// variables should first check is_initialized_ to make sure that the values of 45a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// member variables are valid. 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass CUDAFftPlan : public fft::Plan { 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 48a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang CUDAFftPlan() 49a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang : parent_(nullptr), 50a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang plan_(-1), 51a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang fft_type_(fft::Type::kInvalid), 52a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang scratch_(nullptr), 534c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower scratch_size_bytes_(0), 54a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang is_initialized_(false) {} 55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~CUDAFftPlan() override; 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Get FFT direction in cuFFT based on FFT type. 
58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int GetFftDirection() const; 59a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang cufftHandle GetPlan() const { 60a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang if (IsInitialized()) { 61a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang return plan_; 62a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang } else { 63a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang LOG(FATAL) << "Try to get cufftHandle value before initialization."; 64a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang } 65a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang } 66a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang 67a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang // Initialize function for batched plan 68a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang port::Status Initialize(CUDAExecutor *parent, Stream *stream, int rank, 69a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang uint64 *elem_count, uint64 *input_embed, 70a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang uint64 input_stride, uint64 input_distance, 71a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang uint64 *output_embed, uint64 output_stride, 72a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang uint64 output_distance, fft::Type type, 73a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang int batch_count, ScratchAllocator *scratch_allocator); 74a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang 75a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang // Initialize function for 1d,2d, and 3d plan 76a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang port::Status Initialize(CUDAExecutor *parent, Stream *stream, int rank, 77a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang uint64 *elem_count, fft::Type type, 78a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang ScratchAllocator *scratch_allocator); 79a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang 
804c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower port::Status UpdateScratchAllocator(Stream *stream, 814c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower ScratchAllocator *scratch_allocator); 824c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower 83a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang protected: 84a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang bool IsInitialized() const { return is_initialized_; } 85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CUDAExecutor *parent_; 88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur cufftHandle plan_; 89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur fft::Type fft_type_; 90a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang DeviceMemory<uint8> scratch_; 914c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower size_t scratch_size_bytes_; 92a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang bool is_initialized_; 93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// FFT support for CUDA platform via cuFFT library. 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// 97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// This satisfies the platform-agnostic FftSupport interface. 98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// 99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Note that the cuFFT handle that this encapsulates is implicitly tied to the 100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// context (and, as a result, the device) that the parent CUDAExecutor is tied 101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// to. 
This simply happens as an artifact of creating the cuFFT handle when a 102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// CUDA context is active. 103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// 104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Thread-safe. The CUDA context associated with all operations is the CUDA 105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// context of parent_, so all context is explicit. 106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass CUDAFft : public fft::FftSupport { 107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit CUDAFft(CUDAExecutor *parent) : parent_(parent) {} 109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~CUDAFft() override {} 110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES 112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur CUDAExecutor *parent_; 115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Two helper functions that execute dynload::cufftExec?2?. 117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // This is for complex to complex FFT, when the direction is required. 
119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <typename FuncT, typename InputT, typename OutputT> 120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan, 121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur FuncT cufft_exec, 122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const DeviceMemory<InputT> &input, 123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceMemory<OutputT> *output); 124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // This is for complex to real or real to complex FFT, when the direction 126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // is implied. 127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur template <typename FuncT, typename InputT, typename OutputT> 128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufft_exec, 129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const DeviceMemory<InputT> &input, 130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceMemory<OutputT> *output); 131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur SE_DISALLOW_COPY_AND_ASSIGN(CUDAFft); 133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace cuda 136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace gputools 137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace perftools 138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ 140