/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"

#include <string>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace se = ::perftools::gputools;

namespace xla {
namespace gpu {

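// FftScratchAllocator hands out scratch memory for FFT plan creation and
// releases every buffer it allocated when it is destroyed.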
FftScratchAllocator::FftScratchAllocator(
    int device_ordinal, DeviceMemoryAllocator* memory_allocator)
    : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}

FftScratchAllocator::~FftScratchAllocator() {
  for (auto& allocated_buffer : allocated_buffers_) {
    if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
             .ok()) {
      // A failed deallocation is not fatal; log it and keep going.
      LOG(ERROR) << "Failed to deallocate the allocated buffer: "
                 << allocated_buffer.opaque();
    }
  }
}

int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
  constexpr int64 kFftScratchSize = 1LL << 32;  // Allow up to 4GiB of scratch.
  return kFftScratchSize;
}

se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
    se::Stream* stream, int64 byte_size) {
  CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
  if (byte_size > GetMemoryLimitInBytes(stream)) {
    return se::port::Status(
        se::port::error::RESOURCE_EXHAUSTED,
        tensorflow::strings::Printf(
            "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
            byte_size, GetMemoryLimitInBytes(stream)));
  }

  auto status_or_memory =
      memory_allocator_->Allocate(device_ordinal_, byte_size,
                                  /*retry_on_failure=*/false);
  if (!status_or_memory.ok()) {
    return se::port::Status(
        se::port::error::RESOURCE_EXHAUSTED,
        tensorflow::strings::Printf(
            "Failed to allocate %lld bytes on device %d.", byte_size,
            device_ordinal_));
  }
  se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
  allocated_buffers_.push_back(allocated_buffer);
  total_allocated_bytes_ += byte_size;
  return se::DeviceMemory<uint8>(allocated_buffer);
}


namespace {

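// Maps an XLA FftType to the corresponding StreamExecutor fft::Type.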
se::fft::Type FftTypeToSeType(FftType type) {
  switch (type) {
    case FftType::FFT:
      return se::fft::Type::kC2CForward;
    case FftType::IFFT:
      return se::fft::Type::kC2CInverse;
    case FftType::IRFFT:
      return se::fft::Type::kC2R;
    case FftType::RFFT:
      return se::fft::Type::kR2C;
    default:
      LOG(FATAL) << "unsupported fft type";
  }
}

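// Returns a readable name for the given fft::Type, used in log and error
// messages.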
string FftTypeToString(se::fft::Type type) {
  switch (type) {
    case se::fft::Type::kC2CForward:
      return "FFT";
    case se::fft::Type::kC2CInverse:
      return "IFFT";
    case se::fft::Type::kC2R:
      return "IRFFT";
    case se::fft::Type::kR2C:
      return "RFFT";
    default:
      LOG(FATAL) << "unknown fft type";
  }
}

}  // namespace

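// FftThunk runs the FFT for a single HLO instruction through StreamExecutor's
// FFT support (cuFFT on CUDA devices).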
FftThunk::FftThunk(FftType fft_type,
                   tensorflow::gtl::ArraySlice<int64> fft_length,
                   const BufferAllocation::Slice& input_buffer,
                   const BufferAllocation::Slice& output_buffer,
                   const Shape& input_shape, const Shape& output_shape,
                   const HloInstruction* hlo)
    : Thunk(Kind::kFft, hlo),
      fft_type_(FftTypeToSeType(fft_type)),
      fft_length_(fft_length.begin(), fft_length.end()),
      scale_factor_(1.0f),
      input_buffer_(input_buffer),
      output_buffer_(output_buffer),
      input_shape_(input_shape),
      output_shape_(output_shape) {}

tensorflow::Status FftThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream) {
  VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
  VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
  VLOG(3) << "Output shape: "
          << ShapeUtil::HumanStringWithLayout(output_shape_);

  FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(),
                                        buffer_allocations.memory_allocator());

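  // The plan is created lazily on the first execution and reused afterwards;
  // later executions only refresh its scratch space.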
  if (fft_plan_ == nullptr) {
    const int64 fft_rank = fft_length_.size();
    CHECK_LE(fft_rank, 3);
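    // All leading (non-transformed) dimensions are folded into the batch size.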
    int batch_size = 1;
    for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) {
      batch_size *= input_shape_.dimensions(i);
    }
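    // Parameters describing a dense, batched layout to the FFT plan:
    // per-dimension lengths, embedding sizes, element stride, and the distance
    // between consecutive batch elements.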
    uint64 fft_length[3];
    uint64 input_embed[3];
    const uint64 input_stride = 1;
    uint64 input_distance = 1;
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;

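    // The transform covers the trailing fft_rank dimensions; since the data is
    // dense, the embed arrays match those dimensions and the batch distance is
    // their product.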
    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape_.dimensions_size() - fft_rank + i;
      fft_length[i] = static_cast<uint64>(fft_length_[i]);
      input_embed[i] = input_shape_.dimensions(dim_offset);
      input_distance *= input_shape_.dimensions(dim_offset);
      output_embed[i] = output_shape_.dimensions(dim_offset);
      output_distance *= output_shape_.dimensions(dim_offset);
    }

    constexpr bool kInPlaceFft = false;
    fft_plan_ =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_length, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            fft_type_, kInPlaceFft, batch_size, &scratch_allocator);
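    // The underlying FFT library computes unnormalized transforms, so inverse
    // transforms are scaled by 1/N below, where N is the number of output
    // elements per batch.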
    scale_factor_ = 1.0f / output_distance;
  } else {
    stream->parent()->AsFft()->UpdatePlanWithScratchAllocator(
        stream, fft_plan_.get(), &scratch_allocator);
  }

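  // Dispatch on the transform type so the input and output buffers are viewed
  // with the correct element type (complex64 or float).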
  bool launch_ok;
  switch (fft_type_) {
    case se::fft::Type::kC2CForward: {
      se::DeviceMemory<complex64> input_data(
          buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<complex64> output_data(
          buffer_allocations.GetDeviceAddress(output_buffer_));
      launch_ok =
          stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
      break;
    }
    case se::fft::Type::kC2CInverse: {
      se::DeviceMemory<complex64> input_data(
          buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<complex64> output_data(
          buffer_allocations.GetDeviceAddress(output_buffer_));
      launch_ok =
          stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
      if (launch_ok) {
        launch_ok =
            stream
                ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
                               complex64(scale_factor_), &output_data, 1)
                .ok();
      }
      break;
    }
    case se::fft::Type::kR2C: {
      se::DeviceMemory<float> input_data(
          buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<complex64> output_data(
          buffer_allocations.GetDeviceAddress(output_buffer_));
      launch_ok =
          stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
      break;
    }
    case se::fft::Type::kC2R: {
      se::DeviceMemory<complex64> input_data(
          buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<float> output_data(
          buffer_allocations.GetDeviceAddress(output_buffer_));
      launch_ok =
          stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
      if (launch_ok) {
        launch_ok = stream
                        ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
                                       scale_factor_, &output_data, 1)
                        .ok();
      }
      break;
    }
    default:
      LOG(FATAL) << "unsupported fft type";
  }
  if (launch_ok) {
    return tensorflow::Status::OK();
  }
  return InternalError("Unable to launch fft for thunk %p with type %s", this,
                       FftTypeToString(fft_type_).c_str());
}

}  // namespace gpu
}  // namespace xla