1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" 17 18#include <string> 19 20#include "tensorflow/compiler/xla/types.h" 21#include "tensorflow/compiler/xla/util.h" 22#include "tensorflow/core/lib/strings/strcat.h" 23#include "tensorflow/core/lib/strings/stringprintf.h" 24#include "tensorflow/core/platform/logging.h" 25#include "tensorflow/core/platform/stream_executor_no_cuda.h" 26 27namespace se = ::perftools::gputools; 28 29namespace xla { 30namespace gpu { 31 32FftScratchAllocator::FftScratchAllocator( 33 int device_ordinal, DeviceMemoryAllocator* memory_allocator) 34 : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} 35 36FftScratchAllocator::~FftScratchAllocator() { 37 for (auto& allocated_buffer : allocated_buffers_) { 38 if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) 39 .ok()) { 40 // The program can still continue with failed deallocation. 41 LOG(ERROR) << "Failed to deallocate the allocated buffer: " 42 << allocated_buffer.opaque(); 43 } 44 } 45} 46 47int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { 48 constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. 49 return kFftScratchSize; 50} 51 52se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes( 53 se::Stream* stream, int64 byte_size) { 54 CHECK_GE(byte_size, 0) << "byte_size must be positive."; 55 if (byte_size > GetMemoryLimitInBytes(stream)) { 56 return se::port::Status( 57 se::port::error::RESOURCE_EXHAUSTED, 58 tensorflow::strings::Printf( 59 "Allocating %lld bytes exceeds the memory limit of %lld bytes.", 60 byte_size, GetMemoryLimitInBytes(stream))); 61 } 62 63 auto status_or_memory = 64 memory_allocator_->Allocate(device_ordinal_, byte_size, 65 /*retry_on_failure=*/false); 66 if (!status_or_memory.ok()) { 67 return tensorflow::errors::ResourceExhausted( 68 "Failed to allocate %lld bytes on device %d.", byte_size, 69 device_ordinal_); 70 } 71 se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); 72 allocated_buffers_.push_back(allocated_buffer); 73 total_allocated_bytes_ += byte_size; 74 return se::DeviceMemory<uint8>(allocated_buffer); 75} 76 77namespace { 78 79se::fft::Type FftTypeToSeType(FftType type) { 80 switch (type) { 81 case FftType::FFT: 82 return se::fft::Type::kC2CForward; 83 case FftType::IFFT: 84 return se::fft::Type::kC2CInverse; 85 case FftType::IRFFT: 86 return se::fft::Type::kC2R; 87 case FftType::RFFT: 88 return se::fft::Type::kR2C; 89 default: 90 LOG(FATAL) << "unsupported fft type"; 91 } 92} 93 94string FftTypeToString(se::fft::Type type) { 95 switch (type) { 96 case se::fft::Type::kC2CForward: 97 return "FFT"; 98 case se::fft::Type::kC2CInverse: 99 return "IFFT"; 100 case se::fft::Type::kC2R: 101 return "IRFFT"; 102 case se::fft::Type::kR2C: 103 return "RFFT"; 104 default: 105 LOG(FATAL) << "unknown fft type"; 106 } 107} 108 109} // namespace 110 111FftThunk::FftThunk(FftType fft_type, 112 tensorflow::gtl::ArraySlice<int64> fft_length, 113 const BufferAllocation::Slice& input_buffer, 114 const BufferAllocation::Slice& output_buffer, 115 const Shape& input_shape, const Shape& output_shape, 116 const HloInstruction* hlo) 117 : Thunk(Kind::kFft, hlo), 118 fft_type_(FftTypeToSeType(fft_type)), 119 fft_length_(fft_length.begin(), fft_length.end()), 120 scale_factor_(1.0f), 121 input_buffer_(input_buffer), 122 output_buffer_(output_buffer), 123 input_shape_(input_shape), 124 output_shape_(output_shape) {} 125 126tensorflow::Status FftThunk::ExecuteOnStream( 127 const BufferAllocations& buffer_allocations, se::Stream* stream) { 128 VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); 129 VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); 130 VLOG(3) << "Output shape: " 131 << ShapeUtil::HumanStringWithLayout(output_shape_); 132 133 FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(), 134 buffer_allocations.memory_allocator()); 135 136 if (fft_plan_ == nullptr) { 137 const int64 fft_rank = fft_length_.size(); 138 CHECK_LE(fft_rank, 3); 139 int batch_size = 1; 140 for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) { 141 batch_size *= input_shape_.dimensions(i); 142 } 143 uint64 fft_length[3]; 144 uint64 input_embed[3]; 145 const uint64 input_stride = 1; 146 uint64 input_distance = 1; 147 uint64 output_embed[3]; 148 const uint64 output_stride = 1; 149 uint64 output_distance = 1; 150 151 for (int i = 0; i < fft_rank; ++i) { 152 auto dim_offset = input_shape_.dimensions_size() - fft_rank + i; 153 fft_length[i] = static_cast<uint64>(fft_length_[i]); 154 input_embed[i] = input_shape_.dimensions(dim_offset); 155 input_distance *= input_shape_.dimensions(dim_offset); 156 output_embed[i] = output_shape_.dimensions(dim_offset); 157 output_distance *= output_shape_.dimensions(dim_offset); 158 } 159 160 constexpr bool kInPlaceFft = false; 161 fft_plan_ = 162 stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator( 163 stream, fft_rank, fft_length, input_embed, input_stride, 164 input_distance, output_embed, output_stride, output_distance, 165 fft_type_, kInPlaceFft, batch_size, &scratch_allocator); 166 scale_factor_ = 1.0f / output_distance; 167 } else { 168 stream->parent()->AsFft()->UpdatePlanWithScratchAllocator( 169 stream, fft_plan_.get(), &scratch_allocator); 170 } 171 172 bool launch_ok; 173 switch (fft_type_) { 174 case se::fft::Type::kC2CForward: { 175 se::DeviceMemory<complex64> input_data( 176 buffer_allocations.GetDeviceAddress(input_buffer_)); 177 se::DeviceMemory<complex64> output_data( 178 buffer_allocations.GetDeviceAddress(output_buffer_)); 179 launch_ok = 180 stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); 181 break; 182 } 183 case se::fft::Type::kC2CInverse: { 184 se::DeviceMemory<complex64> input_data( 185 buffer_allocations.GetDeviceAddress(input_buffer_)); 186 se::DeviceMemory<complex64> output_data( 187 buffer_allocations.GetDeviceAddress(output_buffer_)); 188 launch_ok = 189 stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); 190 if (launch_ok) { 191 launch_ok = 192 stream 193 ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), 194 complex64(scale_factor_), &output_data, 1) 195 .ok(); 196 } 197 break; 198 } 199 case se::fft::Type::kR2C: { 200 se::DeviceMemory<float> input_data( 201 buffer_allocations.GetDeviceAddress(input_buffer_)); 202 se::DeviceMemory<complex64> output_data( 203 buffer_allocations.GetDeviceAddress(output_buffer_)); 204 launch_ok = 205 stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); 206 break; 207 } 208 case se::fft::Type::kC2R: { 209 se::DeviceMemory<complex64> input_data( 210 buffer_allocations.GetDeviceAddress(input_buffer_)); 211 se::DeviceMemory<float> output_data( 212 buffer_allocations.GetDeviceAddress(output_buffer_)); 213 launch_ok = 214 stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); 215 if (launch_ok) { 216 launch_ok = stream 217 ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), 218 scale_factor_, &output_data, 1) 219 .ok(); 220 } 221 break; 222 } 223 default: 224 LOG(FATAL) << "unsupported fft type"; 225 } 226 if (launch_ok) { 227 return tensorflow::Status::OK(); 228 } 229 return InternalError("Unable to launch fft for thunk %p with type %s", this, 230 FftTypeToString(fft_type_).c_str()); 231} 232 233} // namespace gpu 234} // namespace xla 235