1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/platform_util.h" 17 18#include <algorithm> 19#include <string> 20#include <utility> 21 22#include "tensorflow/compiler/xla/service/compiler.h" 23#include "tensorflow/compiler/xla/status_macros.h" 24#include "tensorflow/compiler/xla/statusor.h" 25#include "tensorflow/compiler/xla/types.h" 26#include "tensorflow/compiler/xla/util.h" 27#include "tensorflow/core/lib/core/threadpool.h" 28#include "tensorflow/core/lib/strings/str_util.h" 29#include "tensorflow/core/platform/logging.h" 30#include "tensorflow/core/platform/stream_executor_no_cuda.h" 31 32namespace se = ::perftools::gputools; 33 34namespace xla { 35 36using tensorflow::str_util::Lowercase; 37 38// Minimum supported CUDA compute capability is 3.5. 39constexpr int kMinCudaComputeCapabilityMajor = 3; 40constexpr int kMinCudaComputeCapabilityMinor = 5; 41 42// The name of the interpreter platform. 43constexpr char kInterpreter[] = "interpreter"; 44 45namespace { 46 47string CanonicalPlatformName(const string& name) { 48 string platform_str = Lowercase(name); 49 // "cpu" and "host" mean the same thing. 50 if (platform_str == "cpu") { 51 platform_str = "host"; 52 } 53 // "gpu" and "cuda" mean the same thing. 
54 if (platform_str == "gpu") { 55 platform_str = "cuda"; 56 } 57 return platform_str; 58} 59 60} // namespace 61 62/* static */ StatusOr<std::vector<se::Platform*>> 63PlatformUtil::GetSupportedPlatforms() { 64 se::MultiPlatformManager::PlatformMap platform_map; 65 se::port::Status platforms_status = se::MultiPlatformManager::WithPlatforms( 66 [&platform_map](se::MultiPlatformManager::PlatformMap* map) { 67 platform_map = *map; 68 return se::port::Status::OK(); 69 }); 70 if (platform_map.empty()) { 71 LOG(WARNING) << "no executor platforms available: platform map is empty"; 72 } 73 74 // Gather all platforms which have an XLA compiler. 75 std::vector<se::Platform*> platforms; 76 for (auto& platform_pair : platform_map) { 77 auto* platform = platform_pair.second; 78 auto compiler_status = Compiler::GetForPlatform(platform); 79 if (compiler_status.ok()) { 80 if (platform->VisibleDeviceCount() > 0) { 81 LOG(INFO) << "platform " << platform->Name() << " present with " 82 << platform->VisibleDeviceCount() << " visible devices"; 83 } else { 84 LOG(WARNING) << "platform " << platform->Name() << " present but no " 85 << "visible devices found"; 86 } 87 // Note: currently we call zero device platforms "supported" on the basis 88 // that, if the platform support was linked in, it was probably intended 89 // to be used for execution, and this way we can flag an error. 90 // 91 // TODO(b/33730287) If we want an alternative version of this behavior we 92 // could add an --xla_fallback_to_host flag. 
93 platforms.push_back(platform); 94 } else { 95 LOG(INFO) << "platform " << platform->Name() << " present but no " 96 << "XLA compiler available: " 97 << compiler_status.status().error_message(); 98 } 99 } 100 return platforms; 101} 102 103/* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() { 104 TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); 105 if (platforms.empty()) { 106 return NotFound("no platforms found"); 107 } else if (platforms.size() == 1) { 108 return platforms[0]; 109 } 110 111 // Multiple platforms present and we can't pick a reasonable default. 112 string platforms_string = tensorflow::str_util::Join( 113 platforms, ", ", 114 [](string* out, const se::Platform* p) { out->append(p->Name()); }); 115 return InvalidArgument( 116 "must specify platform because more than one platform found: %s", 117 platforms_string.c_str()); 118} 119 120/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() { 121 TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); 122 if (platforms.empty()) { 123 return NotFound("no platforms found"); 124 } else if (platforms.size() == 1) { 125 return platforms[0]; 126 } else if (platforms.size() == 2) { 127 for (int i = 0; i < 2; i++) { 128 if (Lowercase(platforms[i]->Name()) == kInterpreter && 129 Lowercase(platforms[1 - i]->Name()) != kInterpreter) { 130 return platforms[1 - i]; 131 } 132 } 133 } 134 135 // Multiple platforms present and we can't pick a reasonable default. 
136 string platforms_string = tensorflow::str_util::Join( 137 platforms, ", ", 138 [](string* out, const se::Platform* p) { out->append(p->Name()); }); 139 return InvalidArgument( 140 "must specify platform because more than one platform (except for the " 141 "interpreter platform) found: %s", 142 platforms_string.c_str()); 143} 144 145/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform( 146 const string& platform_name) { 147 string platform_str = CanonicalPlatformName(platform_name); 148 TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); 149 for (se::Platform* platform : platforms) { 150 if (Lowercase(platform->Name()) == platform_str) { 151 return platform; 152 } 153 } 154 return InvalidArgument("platform %s not found", platform_name.c_str()); 155} 156 157/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor( 158 const string& platform_name) { 159 string platform_str = CanonicalPlatformName(platform_name); 160 161 TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); 162 std::vector<se::Platform*> matched; 163 for (se::Platform* platform : platforms) { 164 if (Lowercase(platform->Name()) != platform_name) { 165 matched.push_back(platform); 166 } 167 } 168 if (matched.empty()) { 169 return InvalidArgument("unable to find platform that is not %s", 170 platform_name.c_str()); 171 } 172 if (matched.size() == 1) { 173 return matched[0]; 174 } 175 string matched_string = tensorflow::str_util::Join( 176 matched, ", ", 177 [](string* out, const se::Platform* p) { out->append(p->Name()); }); 178 return InvalidArgument( 179 "found multiple platforms %s, but expected one platform except for %s", 180 matched_string.c_str(), platform_name.c_str()); 181} 182 183// Returns whether the device underlying the given StreamExecutor is supported 184// by XLA. 
// Returns true unless `executor` is a CUDA device whose reported compute
// capability is below the 3.5 minimum. Non-CUDA platforms and CUDA devices
// that do not report a compute capability are accepted.
static bool IsDeviceSupported(se::StreamExecutor* executor) {
  const auto& description = executor->GetDeviceDescription();
  if (executor->platform()->id() == se::cuda::kCudaPlatformId) {
    // CUDA devices must have a minimum compute capability.
    int major_version, minor_version;
    if (description.cuda_compute_capability(&major_version, &minor_version)) {
      if (major_version < kMinCudaComputeCapabilityMajor ||
          (major_version == kMinCudaComputeCapabilityMajor &&
           minor_version < kMinCudaComputeCapabilityMinor)) {
        LOG(INFO) << "StreamExecutor cuda device ("
                  << executor->device_ordinal() << ") is of "
                  << "insufficient compute capability: "
                  << kMinCudaComputeCapabilityMajor << "."
                  << kMinCudaComputeCapabilityMinor << " required, "
                  << "device is " << major_version << "." << minor_version;
        return false;
      }
    }
  }
  return true;
}

// Initializes a StreamExecutor for every visible device of `platform`, in
// parallel. Slots for devices that fail to initialize or are unsupported are
// left null; an error is returned only when NO device could be initialized.
/* static */ StatusOr<std::vector<se::StreamExecutor*>>
PlatformUtil::GetStreamExecutors(se::Platform* platform) {
  int device_count = platform->VisibleDeviceCount();
  if (device_count <= 0) {
    return NotFound("no %s devices found", platform->Name().c_str());
  }
  if (platform->id() == se::host::kHostPlatformId) {
    // On host "devices", StreamExecutor exports a device for each hardware
    // thread. Because we parallelize a single computation across threads, it
    // doesn't make sense to expose these as separate devices, so fix the number
    // of devices to one.
    device_count = 1;
  }
  // Pre-sized so each worker task writes only its own index `i`; no two tasks
  // touch the same slot, so no locking is needed.
  std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
  VLOG(1) << "Initializing devices";
  {
    // Device initialization (e.g. CUDA context creation) can be slow, so it is
    // fanned out across one thread per device.
    tensorflow::thread::ThreadPool thread_pool(
        tensorflow::Env::Default(), "device_initialization", device_count);
    for (int i = 0; i < device_count; ++i) {
      thread_pool.Schedule([platform, i, &stream_executors]() {
        VLOG(1) << "Started device init " << i;
        se::StreamExecutorConfig config;
        config.ordinal = i;
        auto executor_status = platform->GetExecutor(config);
        if (executor_status.ok()) {
          se::StreamExecutor* executor = executor_status.ValueOrDie();
          if (IsDeviceSupported(executor)) {
            stream_executors[i] = executor;
          }
        } else {
          // A failed device is logged but not fatal: the slot stays null and
          // the all-null check below decides whether to error out.
          LOG(WARNING) << "unable to create StreamExecutor for "
                       << platform->Name() << ":" << i << ": "
                       << executor_status.status().error_message();
        }
        VLOG(1) << "Finished device init " << i;
      });
    }
    // Block here in thread_pool destructor until all devices are initialized.
    // NOTE: this scoped-destructor join is load-bearing — stream_executors
    // must not be read until the pool has drained.
  }
  VLOG(1) << "Device initialization complete";
  if (std::all_of(stream_executors.begin(), stream_executors.end(),
                  [](se::StreamExecutor* s) { return s == nullptr; })) {
    // Every device failed or was unsupported; nothing usable to return.
    return InternalError("no supported devices found for platform %s",
                         platform->Name().c_str());
  }
  return stream_executors;
}

}  // namespace xla