/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/cuda/cuda_platform.h"

#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"

namespace perftools {
namespace gputools {
namespace cuda {
namespace {

// Synchronize with spinlocks.
const char kScheduleSpinString[] = "spin";
// Synchronize with spinlocks that also call CPU yield instructions.
const char kScheduleYieldString[] = "yield";
// Synchronize with a "synchronization primitive" (e.g. mutex).
const char kScheduleBlockingSyncString[] = "blocking_sync";

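// Returns DeviceOptions derived from the TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE
// environment variable, falling back to the default options when it is unset.
// Illustrative invocation (binary name is hypothetical):
//   TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE=blocking_sync ./my_gpu_program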
const DeviceOptions GetDeviceOptionsFromEnv() {
  const char* gpu_schedule_string =
      std::getenv("TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE");

  if (gpu_schedule_string == nullptr) {
    return perftools::gputools::DeviceOptions::Default();
  }

  unsigned device_flags = 0;
  if (strcmp(kScheduleSpinString, gpu_schedule_string) == 0) {
    device_flags = perftools::gputools::DeviceOptions::kScheduleSpin;
  } else if (strcmp(kScheduleYieldString, gpu_schedule_string) == 0) {
    device_flags = perftools::gputools::DeviceOptions::kScheduleYield;
  } else if (strcmp(kScheduleBlockingSyncString, gpu_schedule_string) == 0) {
    device_flags = perftools::gputools::DeviceOptions::kScheduleBlockingSync;
  } else {
    LOG(QFATAL) << "Unknown option for environment variable "
                   "TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE "
                << gpu_schedule_string << " should be one of {"
                << kScheduleBlockingSyncString << ", " << kScheduleSpinString
                << ", " << kScheduleYieldString << "}";
  }

  return perftools::gputools::DeviceOptions(device_flags);
}

}  // namespace

CudaPlatform::CudaPlatform()
    : name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {}

CudaPlatform::~CudaPlatform() {}

// Due to legacy issues in user code, we can't currently call InspectNumaNodes
// at module initialization time, because non-GPU programs still include this
// plugin via various methods. Instead, it has to be init-on-reference.
void CudaPlatform::InspectNumaNodes() {
  // To get NUMA node information, we need to create all executors, so we can
  // examine their device descriptions to see their bus assignments.
  static bool initialized = false;
  static mutex numa_mutex(LINKER_INITIALIZED);
  mutex_lock lock(numa_mutex);
  if (initialized) {
    return;
  }

  StreamExecutorConfig config;
  for (int i = 0; i < VisibleDeviceCount(); i++) {
    config.ordinal = i;
    StreamExecutor* exec = GetExecutor(config).ValueOrDie();
    if (i == 0) {
      // NUMA nodes may not start at 0, so set the minimum node based on the
      // first executor we see.
      min_numa_node_ = exec->GetDeviceDescription().numa_node();
      limit_numa_node_ = min_numa_node_ + 1;
    } else {
      min_numa_node_ =
          std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
      limit_numa_node_ = std::max(limit_numa_node_,
                                  exec->GetDeviceDescription().numa_node() + 1);
    }
  }
  initialized = true;
}

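// Returns the number of distinct NUMA nodes (buses) spanned by the visible
// devices, i.e. the size of the half-open range [min_numa_node_,
// limit_numa_node_).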
int CudaPlatform::BusCount() {
  InspectNumaNodes();
  return limit_numa_node_ - min_numa_node_;
}

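// Maps a device ordinal to its zero-based bus index: the device's NUMA node
// offset from min_numa_node_.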
int CudaPlatform::DeviceToBus(int device_ordinal) {
  StreamExecutorConfig config;
  config.ordinal = device_ordinal;
  StreamExecutor* exec = GetExecutor(config).ValueOrDie();
  return exec->GetDeviceDescription().numa_node() - min_numa_node_;
}

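// Returns the first visible executor whose bus matches bus_ordinal, or a
// NOT_FOUND status if no such device exists.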
port::StatusOr<StreamExecutor*> CudaPlatform::FirstExecutorForBus(
    int bus_ordinal) {
  InspectNumaNodes();
  CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";
  for (int i = 0; i < VisibleDeviceCount(); i++) {
    if (DeviceToBus(i) == bus_ordinal) {
      StreamExecutorConfig config;
      config.ordinal = i;
      return GetExecutor(config).ValueOrDie();
    }
  }

  return port::Status{
      port::error::NOT_FOUND,
      port::Printf("Executor for bus %d not found.", bus_ordinal)};
}

Platform::Id CudaPlatform::id() const { return kCudaPlatformId; }

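// Returns the number of CUDA devices visible to this process, or -1 if the
// CUDA driver fails to initialize.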
int CudaPlatform::VisibleDeviceCount() const {
  // Throw away the result - it logs internally, and this [containing] function
  // isn't in the path of user control. It's safe to call this more than once.
  if (!cuda::CUDADriver::Init().ok()) {
    return -1;
  }

  return CUDADriver::GetDeviceCount();
}

const string& CudaPlatform::Name() const { return name_; }

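// Returns a (possibly cached) executor for `ordinal`, using the default
// plugin configuration and device options taken from the environment.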
port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDevice(int ordinal) {
  StreamExecutorConfig config;
  config.ordinal = ordinal;
  config.plugin_config = PluginConfig();
  config.device_options = GetDeviceOptionsFromEnv();
  return GetExecutor(config);
}

port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDeviceWithPluginConfig(
    int device_ordinal, const PluginConfig& plugin_config) {
  StreamExecutorConfig config;
  config.ordinal = device_ordinal;
  config.plugin_config = plugin_config;
  config.device_options = GetDeviceOptionsFromEnv();
  return GetExecutor(config);
}

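// Returns the executor cached for `config`, creating and caching a new one via
// GetUncachedExecutor() when no matching executor exists yet.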
port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
    const StreamExecutorConfig& config) {
  return executor_cache_.GetOrCreate(
      config, [&]() { return GetUncachedExecutor(config); });
}

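// Constructs a fresh StreamExecutor backed by a CUDAExecutor and initializes
// it for the requested device ordinal; returns an INTERNAL error if
// initialization fails.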
port::StatusOr<std::unique_ptr<StreamExecutor>>
CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
  auto executor = port::MakeUnique<StreamExecutor>(
      this, port::MakeUnique<CUDAExecutor>(config.plugin_config));
  auto init_status = executor->Init(config.ordinal, config.device_options);
  if (!init_status.ok()) {
    return port::Status{
        port::error::INTERNAL,
        port::Printf(
            "failed initializing StreamExecutor for CUDA device ordinal %d: %s",
            config.ordinal, init_status.ToString().c_str())};
  }

  return std::move(executor);
}

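// Trace listeners are not yet supported on the CUDA platform; both hooks below
// fail fatally if called.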
void CudaPlatform::RegisterTraceListener(
    std::unique_ptr<TraceListener> listener) {
  LOG(FATAL) << "not yet implemented: register CUDA trace listener";
}

void CudaPlatform::UnregisterTraceListener(TraceListener* listener) {
  LOG(FATAL) << "not yet implemented: unregister CUDA trace listener";
}

}  // namespace cuda

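// Creates the CudaPlatform singleton and registers it with the
// MultiPlatformManager; invoked once via the module initializer below.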
static void InitializeCudaPlatform() {
  // Leak checking is disabled for this allocation: MultiPlatformManager does
  // not destroy its registered platforms, so the CudaPlatform instance is
  // intentionally never freed.
  std::unique_ptr<cuda::CudaPlatform> platform(new cuda::CudaPlatform);
  SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));
}

}  // namespace gputools
}  // namespace perftools

REGISTER_MODULE_INITIALIZER(cuda_platform,
                            perftools::gputools::InitializeCudaPlatform());

DECLARE_MODULE_INITIALIZER(multi_platform_manager);
// Note that module initialization sequencing is not supported in the
// open-source project, so this will be a no-op there.
REGISTER_MODULE_INITIALIZER_SEQUENCE(cuda_platform, multi_platform_manager);