/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/platform_util.h"

#include <algorithm>
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace se = ::perftools::gputools;

namespace xla {

using tensorflow::str_util::Lowercase;

// Minimum supported CUDA compute capability is 3.5.
constexpr int kMinCudaComputeCapabilityMajor = 3;
constexpr int kMinCudaComputeCapabilityMinor = 5;

// The name of the interpreter platform.
constexpr char kInterpreter[] = "interpreter";

namespace {

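// Returns the canonical platform name for `name`: "cpu" maps to "host" and
// "gpu" maps to "cuda".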
string CanonicalPlatformName(const string& name) {
  string platform_str = Lowercase(name);
  // "cpu" and "host" mean the same thing.
  if (platform_str == "cpu") {
    platform_str = "host";
  }
  // "gpu" and "cuda" mean the same thing.
  if (platform_str == "gpu") {
    platform_str = "cuda";
  }
  return platform_str;
}

}  // namespace

/* static */ StatusOr<std::vector<se::Platform*>>
PlatformUtil::GetSupportedPlatforms() {
  se::MultiPlatformManager::PlatformMap platform_map;
  se::port::Status platforms_status = se::MultiPlatformManager::WithPlatforms(
      [&platform_map](se::MultiPlatformManager::PlatformMap* map) {
        platform_map = *map;
        return se::port::Status::OK();
      });
  if (platform_map.empty()) {
    LOG(WARNING) << "no executor platforms available: platform map is empty";
  }

  // Gather all platforms which have an XLA compiler.
  std::vector<se::Platform*> platforms;
  for (auto& platform_pair : platform_map) {
    auto* platform = platform_pair.second;
    auto compiler_status = Compiler::GetForPlatform(platform);
    if (compiler_status.ok()) {
      if (platform->VisibleDeviceCount() > 0) {
        LOG(INFO) << "platform " << platform->Name() << " present with "
                  << platform->VisibleDeviceCount() << " visible devices";
      } else {
        LOG(WARNING) << "platform " << platform->Name() << " present but no "
                     << "visible devices found";
      }
      // Note: currently we call zero device platforms "supported" on the basis
      // that, if the platform support was linked in, it was probably intended
      // to be used for execution, and this way we can flag an error.
      //
      // TODO(b/33730287) If we want an alternative version of this behavior we
      // could add an --xla_fallback_to_host flag.
      platforms.push_back(platform);
    } else {
      LOG(INFO) << "platform " << platform->Name() << " present but no "
                << "XLA compiler available: "
                << compiler_status.status().error_message();
    }
  }
  return platforms;
}

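// Returns the sole supported platform. Fails if no supported platform is
// found or if more than one is available.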
/* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() {
  TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
  if (platforms.empty()) {
    return NotFound("no platforms found");
  } else if (platforms.size() == 1) {
    return platforms[0];
  }

  // Multiple platforms present and we can't pick a reasonable default.
  string platforms_string = tensorflow::str_util::Join(
      platforms, ", ",
      [](string* out, const se::Platform* p) { out->append(p->Name()); });
  return InvalidArgument(
      "must specify platform because more than one platform found: %s",
      platforms_string.c_str());
}

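// Returns the default platform: the sole supported platform, or, when exactly
// two platforms are present and one of them is the interpreter, the
// non-interpreter platform.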
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
  TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
  if (platforms.empty()) {
    return NotFound("no platforms found");
  } else if (platforms.size() == 1) {
    return platforms[0];
  } else if (platforms.size() == 2) {
    // With exactly two platforms, prefer the non-interpreter platform when the
    // other one is the interpreter.
    for (int i = 0; i < 2; i++) {
      if (Lowercase(platforms[i]->Name()) == kInterpreter &&
          Lowercase(platforms[1 - i]->Name()) != kInterpreter) {
        return platforms[1 - i];
      }
    }
  }

  // Multiple platforms present and we can't pick a reasonable default.
  string platforms_string = tensorflow::str_util::Join(
      platforms, ", ",
      [](string* out, const se::Platform* p) { out->append(p->Name()); });
  return InvalidArgument(
      "must specify platform because more than one platform (other than the "
      "interpreter) was found: %s",
      platforms_string.c_str());
}

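// Returns the supported platform whose canonicalized name matches
// `platform_name`.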
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
    const string& platform_name) {
  string platform_str = CanonicalPlatformName(platform_name);
  TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
  for (se::Platform* platform : platforms) {
    if (Lowercase(platform->Name()) == platform_str) {
      return platform;
    }
  }
  return InvalidArgument("platform %s not found", platform_name.c_str());
}

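// Returns the single supported platform whose canonicalized name does not
// match `platform_name`. Fails if zero or more than one such platform is
// found.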
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
    const string& platform_name) {
  string platform_str = CanonicalPlatformName(platform_name);

  TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
  std::vector<se::Platform*> matched;
  for (se::Platform* platform : platforms) {
    // Compare against the canonicalized name so that, e.g., excluding "cpu"
    // also excludes the "host" platform.
    if (Lowercase(platform->Name()) != platform_str) {
      matched.push_back(platform);
    }
  }
  if (matched.empty()) {
    return InvalidArgument("unable to find platform that is not %s",
                           platform_name.c_str());
  }
  if (matched.size() == 1) {
    return matched[0];
  }
  string matched_string = tensorflow::str_util::Join(
      matched, ", ",
      [](string* out, const se::Platform* p) { out->append(p->Name()); });
  return InvalidArgument(
      "found multiple platforms %s, but expected one platform except for %s",
      matched_string.c_str(), platform_name.c_str());
}

// Returns whether the device underlying the given StreamExecutor is supported
// by XLA.
static bool IsDeviceSupported(se::StreamExecutor* executor) {
  const auto& description = executor->GetDeviceDescription();
  if (executor->platform()->id() == se::cuda::kCudaPlatformId) {
    // CUDA devices must have a minimum compute capability.
    int major_version, minor_version;
    if (description.cuda_compute_capability(&major_version, &minor_version)) {
      if (major_version < kMinCudaComputeCapabilityMajor ||
          (major_version == kMinCudaComputeCapabilityMajor &&
           minor_version < kMinCudaComputeCapabilityMinor)) {
        LOG(INFO) << "StreamExecutor cuda device ("
                  << executor->device_ordinal() << ") is of "
                  << "insufficient compute capability: "
                  << kMinCudaComputeCapabilityMajor << "."
                  << kMinCudaComputeCapabilityMinor << " required, "
                  << "device is " << major_version << "." << minor_version;
        return false;
      }
    }
  }
  return true;
}

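// Initializes a StreamExecutor for every device of `platform` in parallel.
// Entries for devices that are unsupported or fail to initialize are left
// null; an error is returned only if no device could be initialized.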
/* static */ StatusOr<std::vector<se::StreamExecutor*>>
PlatformUtil::GetStreamExecutors(se::Platform* platform) {
  int device_count = platform->VisibleDeviceCount();
  if (device_count <= 0) {
    return NotFound("no %s devices found", platform->Name().c_str());
  }
  if (platform->id() == se::host::kHostPlatformId) {
    // On host "devices", StreamExecutor exports a device for each hardware
    // thread. Because we parallelize a single computation across threads, it
    // doesn't make sense to expose these as separate devices, so fix the
    // number of devices to one.
    device_count = 1;
  }
  std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
  VLOG(1) << "Initializing devices";
  {
    tensorflow::thread::ThreadPool thread_pool(
        tensorflow::Env::Default(), "device_initialization", device_count);
    for (int i = 0; i < device_count; ++i) {
      thread_pool.Schedule([platform, i, &stream_executors]() {
        VLOG(1) << "Started device init " << i;
        se::StreamExecutorConfig config;
        config.ordinal = i;
        auto executor_status = platform->GetExecutor(config);
        if (executor_status.ok()) {
          se::StreamExecutor* executor = executor_status.ValueOrDie();
          if (IsDeviceSupported(executor)) {
            stream_executors[i] = executor;
          }
        } else {
          LOG(WARNING) << "unable to create StreamExecutor for "
                       << platform->Name() << ":" << i << ": "
                       << executor_status.status().error_message();
        }
        VLOG(1) << "Finished device init " << i;
      });
    }
    // Block here in thread_pool destructor until all devices are initialized.
  }
  VLOG(1) << "Device initialization complete";
  if (std::all_of(stream_executors.begin(), stream_executors.end(),
                  [](se::StreamExecutor* s) { return s == nullptr; })) {
    return InternalError("no supported devices found for platform %s",
                         platform->Name().c_str());
  }
  return stream_executors;
}

}  // namespace xla