/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
// CUDA-specific support for FFT functionality -- this wraps the cuFFT library
// capabilities, and is only included into CUDA implementation code -- it will
// not introduce cuda headers into other code.
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_
21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_
22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
23a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang#include "cuda/include/cufft.h"
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/fft.h"
25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/platform/port.h"
26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/plugin_registry.h"
27a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang#include "tensorflow/stream_executor/scratch_allocator.h"
28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace perftools {
30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace gputools {
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass Stream;
33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace cuda {
35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass CUDAExecutor;
37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
// Opaque and unique identifier for the cuFFT plugin.
39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurextern const PluginId kCuFftPlugin;
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
41a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// CUDAFftPlan uses deferred initialization. Only a single call of
42a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// Initialize() is allowed to properly create cufft plan and set member
43a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// variable is_initialized_ to true. Newly added interface that uses member
44a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// variables should first check is_initialized_ to make sure that the values of
45a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang// member variables are valid.
46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass CUDAFftPlan : public fft::Plan {
47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
48a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  CUDAFftPlan()
49a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang      : parent_(nullptr),
50a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang        plan_(-1),
51a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang        fft_type_(fft::Type::kInvalid),
52a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang        scratch_(nullptr),
534c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower        scratch_size_bytes_(0),
54a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang        is_initialized_(false) {}
55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~CUDAFftPlan() override;
56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Get FFT direction in cuFFT based on FFT type.
58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int GetFftDirection() const;
59a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  cufftHandle GetPlan() const {
60a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang    if (IsInitialized()) {
61a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang      return plan_;
62a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang    } else {
63a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang      LOG(FATAL) << "Try to get cufftHandle value before initialization.";
64a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang    }
65a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  }
66a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang
67a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  // Initialize function for batched plan
68a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  port::Status Initialize(CUDAExecutor *parent, Stream *stream, int rank,
69a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang                          uint64 *elem_count, uint64 *input_embed,
70a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang                          uint64 input_stride, uint64 input_distance,
71a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang                          uint64 *output_embed, uint64 output_stride,
72a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang                          uint64 output_distance, fft::Type type,
73a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang                          int batch_count, ScratchAllocator *scratch_allocator);
74a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang
75a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  // Initialize function for 1d,2d, and 3d plan
76a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  port::Status Initialize(CUDAExecutor *parent, Stream *stream, int rank,
77a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang                          uint64 *elem_count, fft::Type type,
78a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang                          ScratchAllocator *scratch_allocator);
79a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang
804c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower  port::Status UpdateScratchAllocator(Stream *stream,
814c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower                                      ScratchAllocator *scratch_allocator);
824c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower
83a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang protected:
84a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  bool IsInitialized() const { return is_initialized_; }
85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  CUDAExecutor *parent_;
88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  cufftHandle plan_;
89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  fft::Type fft_type_;
90a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  DeviceMemory<uint8> scratch_;
914c0adf2c26345dd63e2b883317f7efb464862532A. Unique TensorFlower  size_t scratch_size_bytes_;
92a05c3ce52ffd4909cb9bf7c155805406b5b1fc06Yangzihao Wang  bool is_initialized_;
93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// FFT support for CUDA platform via cuFFT library.
96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur//
97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// This satisfies the platform-agnostic FftSupport interface.
98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur//
99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Note that the cuFFT handle that this encapsulates is implicitly tied to the
100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// context (and, as a result, the device) that the parent CUDAExecutor is tied
101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// to. This simply happens as an artifact of creating the cuFFT handle when a
102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// CUDA context is active.
103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur//
104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Thread-safe. The CUDA context associated with all operations is the CUDA
105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// context of parent_, so all context is explicit.
106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass CUDAFft : public fft::FftSupport {
107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit CUDAFft(CUDAExecutor *parent) : parent_(parent) {}
109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~CUDAFft() override {}
110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES
112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  CUDAExecutor *parent_;
115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Two helper functions that execute dynload::cufftExec?2?.
117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // This is for complex to complex FFT, when the direction is required.
119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  template <typename FuncT, typename InputT, typename OutputT>
120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                  FuncT cufft_exec,
122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                  const DeviceMemory<InputT> &input,
123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                  DeviceMemory<OutputT> *output);
124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // This is for complex to real or real to complex FFT, when the direction
126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // is implied.
127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  template <typename FuncT, typename InputT, typename OutputT>
128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufft_exec,
129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                     const DeviceMemory<InputT> &input,
130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                     DeviceMemory<OutputT> *output);
131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  SE_DISALLOW_COPY_AND_ASSIGN(CUDAFft);
133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace cuda
136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace gputools
137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace perftools
138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_
140