gpu_device.cc revision 94538a944ed71eb8c0c22213fef245e09165f935
/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// TODO(opensource): Use a more generic sounding preprocessor name than
// GOOGLE_CUDA
#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/common_runtime/gpu/gpu_device.h"

#include <stdlib.h>
#include <string.h>
#include <algorithm>

//#include "base/commandlineflags.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/util/device_name_utils.h"

#if defined(PLATFORM_GOOGLE)
DEFINE_bool(brain_gpu_sync_every_op, false,
            "If true, call GPUUtil::Sync() between every dispatched opkernel.");

DEFINE_int32(brain_gpu_max_streams, 1,
             "Max number of GPU streams to use for computation.");
#else
// TODO(opensource): These should be made options in some options struct,
// rather than flags.
bool FLAGS_brain_gpu_sync_every_op = false;
tensorflow::int32 FLAGS_brain_gpu_max_streams = 1;
#endif

namespace gpu = ::perftools::gputools;

namespace tensorflow {

// Eigen Ops directly allocate memory only for temporary buffers used
// during OpKernel::Compute().  The recommended way of allocating such
// memory is via OpKernelContext::allocate_temp().  However, Eigen Ops
// don't have access to OpKernelContext, instead they get access to
// memory directly through the device allocator.  As an Open Source
// project, Eigen assumes allocator semantics similar to those of the
// CUDA memory allocator, and may not work correctly due to race
// conditions if used with some other allocator.  For safety, we need
// to delay deallocation calls out of Eigen until all events on the
// corresponding stream have completed.  The following two classes
// serve this purpose in two different compilation environments.

#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
class EigenAllocator : public ::Eigen::Allocator {
 public:
  explicit EigenAllocator(gpu::Stream* stream, ::tensorflow::Allocator* alloc,
                          EventMgr* em)
      : stream_(stream), allocator_(alloc), em_(em) {}

  void* allocate(size_t num_bytes) const override {
    void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes);
    // Eigen doesn't typically check the return pointer from allocate,
    // so we do it here and die with a more helpful error message.
    if (ret == nullptr) {
      LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating "
                 << num_bytes << ". See error logs for more detailed info.";
    }
    return ret;
  }

  void deallocate(void* buffer) const override {
    em_->ThenDeleteBuffer(stream_, {allocator_, buffer});
  }

 private:
  gpu::Stream* stream_;                 // Not owned.
  ::tensorflow::Allocator* allocator_;  // Not owned.
  ::tensorflow::EventMgr* em_;          // Not owned.

  TF_DISALLOW_COPY_AND_ASSIGN(EigenAllocator);
};

#else
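// Routes Eigen's temporary-buffer allocations for GPU ops through the
// TensorFlow allocator, and defers each deallocation until the work already
// enqueued on the stream has finished (via cudaStreamAddCallback below).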
class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice(const cudaStream_t* cuda_stream, int gpu_id,
                        ::tensorflow::Allocator* alloc)
      : stream_(cuda_stream), allocator_(alloc) {
    Eigen::initializeDeviceProp();
    device_prop_ = &Eigen::m_deviceProperties[gpu_id];
  }

  const cudaStream_t& stream() const override { return *stream_; }
  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes);
    if (ret == nullptr) {
      LOG(FATAL) << "EigenCudaStreamDevice for GPU ran out of memory when "
                 << "allocating " << num_bytes
                 << ". See error logs for more detailed info.";
    }

    return ret;
  }
  void deallocate(void* buffer) const override {
    AsyncFreeData* afData = new AsyncFreeData(allocator_, buffer);
    cudaError_t err = cudaStreamAddCallback(*stream_, asyncFree, afData, 0);
    CHECK_EQ(err, cudaSuccess);
  }

 private:
  struct AsyncFreeData {
    AsyncFreeData(::tensorflow::Allocator* a, void* p)
        : allocator_(a), address_(p) {}
    ::tensorflow::Allocator* allocator_;
    void* address_;
  };

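  // Invoked by the CUDA runtime once all work enqueued on the stream before
  // the cudaStreamAddCallback() call above has completed, so the buffer can
  // no longer be referenced by any in-flight kernel.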
  static void CUDART_CB asyncFree(cudaStream_t stream, cudaError_t status,
                                  void* userData) {
    AsyncFreeData* data = static_cast<AsyncFreeData*>(userData);
    data->allocator_->DeallocateRaw(data->address_);
    delete data;
  }

  const cudaStream_t* stream_;          // Not owned.
  const cudaDeviceProp* device_prop_;   // Not owned.
  ::tensorflow::Allocator* allocator_;  // Not owned.

  TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice);
};

#endif

BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
                             Bytes memory_limit, BusAdjacency bus_adjacency,
                             int gpu_id, const string& physical_device_desc,
                             Allocator* gpu_allocator, Allocator* cpu_allocator)
    : LocalDevice(options, Device::BuildDeviceAttributes(
                               name, DEVICE_GPU, memory_limit, bus_adjacency,
                               physical_device_desc),
                  gpu_allocator),
      gpu_allocator_(gpu_allocator),
      cpu_allocator_(cpu_allocator),
      gpu_id_(gpu_id) {
  gpu::StreamExecutor* executor =
      GPUMachineManager()->ExecutorForDevice(gpu_id_).ValueOrDie();
  if (!executor) {
    LOG(ERROR) << "Failed to get StreamExecutor for device " << gpu_id_;
    return;
  }
  em_.reset(new EventMgr(executor, options.config.gpu_options()));

  if (FLAGS_brain_gpu_max_streams < 1) {
    LOG(FATAL) << "Invalid value for brain_gpu_max_streams.";
  }

  // Create the specified number of GPU streams
  for (int i = 0; i < FLAGS_brain_gpu_max_streams; i++) {
    auto stream = new gpu::Stream(executor);
    stream->Init();
    VLOG(2) << "Created stream[" << i << "] = " << stream;
    streams_.push_back(stream);
    device_contexts_.push_back(new GPUDeviceContext(i, stream));
  }
  gpu_device_info_ = new GpuDeviceInfo;
  gpu_device_info_->stream = streams_[0];
  gpu_device_info_->default_context = device_contexts_[0];
  gpu_device_info_->event_mgr = em_.get();
  set_tensorflow_gpu_device_info(gpu_device_info_);
}

BaseGPUDevice::~BaseGPUDevice() {
  delete gpu_device_info_;
  for (auto ctx : device_contexts_) ctx->Unref();
  gtl::STLDeleteElements(&streams_);
}

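// Maps every node in the graph to one of this device's GPUDeviceContexts
// (and hence to one of its compute streams).  With a single stream the map
// is left empty and the default context is used for all nodes.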
Status BaseGPUDevice::FillContextMap(const Graph* graph,
                                     DeviceContextMap* device_context_map) {
  VLOG(2) << "FillContextMap";

  const auto num_streams = streams_.size();
  // Special case for single stream.
  if (num_streams == 1) {
    return Status::OK();
  }
  const int64 before = Env::Default()->NowMicros();
  gpu_stream_util::AssignStreamsOpts opts;
  opts.max_streams = num_streams;
  std::unordered_map<int, int> node_to_stream_id;
  TF_RETURN_IF_ERROR(
      gpu_stream_util::AssignStreams(graph, opts, &node_to_stream_id));
  int64 elapsed = Env::Default()->NowMicros() - before;
  VLOG(3) << "AssignStreams took " << elapsed << "us";

  // Fill in the context map.  It is OK for this map to contain
  // duplicate DeviceContexts so long as we increment the refcount.
  for (Node* n : graph->nodes()) {
    auto mapped_stream = node_to_stream_id[n->id()];
    CHECK_LE(mapped_stream, num_streams);
    auto ctx = device_contexts_[mapped_stream];
    VLOG(3) << "Assigned stream " << node_to_stream_id[n->id()]
            << " ==> stream[" << ctx->stream_id() << "] for node id " << n->id()
            << " " << n->type_string() << " " << n->name();
    ctx->Ref();
    device_context_map->insert(std::make_pair(n->id(), ctx));
  }

  return Status::OK();
}

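// Synchronously dispatches op_kernel on the stream selected by its
// GPUDeviceContext.  When multiple streams are in use, the compute stream
// first waits for the streams that produced the op's inputs; references to
// the input/output/temp tensors are then held until the stream has drained.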
void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  // ScopedActivity is cheap when tracing is not active, but we can
  // still avoid computing the Hash64 unless tracing is on.
  // TODO(pbar) This would no longer be needed if Ops have a unique id.
  const uint64 id = port::Tracing::IsActive() ? Hash64(op_kernel->name()) : 0;
  port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
                                       id);

  GPUDeviceContext* gpu_device_context = device_contexts_[0];
  if (context->op_device_context() != nullptr) {
    gpu_device_context =
        static_cast<GPUDeviceContext*>(context->op_device_context());
  }
  gpu::Stream* stream = gpu_device_context->stream();
  const auto stream_id = gpu_device_context->stream_id();

  const bool vlog_1 = VLOG_IS_ON(1);
  const bool vlog_2 = vlog_1 && VLOG_IS_ON(2);

  if (vlog_1) {
    VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op "
            << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
            << stream_id << "]";
  }

  // NOTE(tucker): We need to discriminate between Eigen GPU
  // operations and all others.  If an operation is Eigen
  // implemented (or otherwise tries to launch a cuda kernel
  // directly), we need to establish a stacked-scoped environment
  // that directs it to execute on the proper device.  Otherwise we
  // expect the Op to use StreamExecutor directly and correctly.  The
  // way we make this discrimination is quite hacky: At the moment
  // the only non-Eigen GPU Op is the recv-op, which is known to be
  // asynchronous.
  if (op_kernel->is_internal() && op_kernel->type_string() == "_Recv") {
    context->SetStatus(errors::Internal(
        "Invalid synchronous 'Compute' on GPU for '_Recv' op"));
  } else {
    port::Tracing::ScopedAnnotation annotation(op_kernel->name(),
                                               op_kernel->type_string());

    const auto num_streams = streams_.size();
    if (num_streams > 1) {
      // If an input was produced on a different stream than the one this
      // op runs on, make this op's stream wait for that input's stream.
      for (int i = 0; i < context->num_inputs(); ++i) {
        const GPUDeviceContext* idc =
            static_cast<GPUDeviceContext*>(context->input_device_context(i));
        OP_REQUIRES(context, idc != nullptr,
                    errors::Internal("Input device context ", i,
                                     " was not set properly."));
        if (vlog_2) {
          const void* base;
          size_t len;
          if (context->has_input(i)) {
            if (IsRefType(context->input_dtype(i))) {
              Tensor tensor = context->mutable_input(i, false);
              base = DMAHelper::base(&tensor);
              len = tensor.TotalBytes();
            } else {
              const Tensor& tensor = context->input(i);
              base = DMAHelper::base(&tensor);
              len = tensor.TotalBytes();
            }
            VLOG(2) << "Input " << i << " " << base << "  " << len;
            VLOG(2) << "  stream[" << stream_id << "].ThenWaitFor(stream["
                    << idc->stream_id() << "])"
                    << ((idc->stream() == stream) ? " not needed" : "");
          }
        }
        if (idc->stream() != stream) stream->ThenWaitFor(idc->stream());
      }
    }
    gpu::cuda::ScopedActivateExecutorContext scoped_activation{
        stream->parent(), gpu::cuda::MultiOpActivation::kYes};

    if (FLAGS_brain_gpu_sync_every_op) {
      op_kernel->Compute(context);
      if (context->status().ok()) {
        // Note: GPUUtil::Sync() only syncs the default stream.
        // We need to either sync the stream used by this op, or
        // all streams.  Given that this flag is typically used for
        // debugging it makes more sense to sync all GPU activity.
        context->SetStatus(GPUUtil::SyncAll(this));
      }
    } else {
      // Keep a reference to each input before Compute runs, in case it
      // gets deleted. TODO(misard) this will be fixed when the tracking
      // is done right.
      EventMgr::TensorReferenceVector tensor_refs;
      const int N_inputs = context->num_inputs();
      tensor_refs.reserve(N_inputs + context->num_outputs());
      for (int ii = 0; ii < N_inputs; ++ii) {
        if (context->has_input(ii)) {
          if (IsRefType(context->input_dtype(ii))) {
            Tensor in = context->mutable_input(ii, false);
            tensor_refs.push_back(TensorReference(in));
          } else {
            const Tensor& in = context->input(ii);
            tensor_refs.push_back(TensorReference(in));
          }
        }
      }
      op_kernel->Compute(context);
      if (context->status().ok()) {
        // The GPU kernel has been queued, but may not complete for some
        // time.  As soon as this function completes, the caller will
        // discard its refs on the inputs, outputs and any scratch
        // tensors it created. Create additional refs here that will be
        // held until the kernel completes.
        for (int ii = 0; ii < context->num_temps(); ++ii) {
          Tensor* temp = context->temp(ii);
          if (vlog_2) {
            VLOG(2) << "Saving ref to temp Tensor @ " << DMAHelper::base(temp);
          }
          tensor_refs.push_back(TensorReference(*temp));
        }
        for (int ii = 0; ii < context->num_outputs(); ++ii) {
          Tensor* temp = context->mutable_output(ii);
          if (nullptr != temp) {
            tensor_refs.push_back(TensorReference(*temp));
          }
        }
        em_->ThenDeleteTensors(stream, tensor_refs);
      }
    }
  }
}

Status BaseGPUDevice::Sync() { return GPUUtil::Sync(this); }

void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel,
                                 OpKernelContext* context,
                                 AsyncOpKernel::DoneCallback done) {
  GPUDeviceContext* gpu_device_context = device_contexts_[0];
  if (context->op_device_context() != nullptr) {
    gpu_device_context =
        static_cast<GPUDeviceContext*>(context->op_device_context());
  }
  const auto stream_id = gpu_device_context->stream_id();

  VLOG(1) << "GpuDevice::ComputeAsync " << op_kernel->name() << " op "
          << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
          << stream_id << "]";

  port::Tracing::TraceMe activity(
      strings::StrCat(op_kernel->name(), ":", op_kernel->type_string()));
  op_kernel->ComputeAsync(context, done);
}

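// Parses tensor_proto into GPU-compatible host memory and, unless the result
// is requested on the host, copies it synchronously to device memory.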
Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                          const AllocatorAttributes alloc_attrs,
                                          Tensor* tensor) {
  AllocatorAttributes attr;
  attr.set_on_host(true);
  attr.set_gpu_compatible(true);
  Allocator* host_alloc = GetAllocator(attr);
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(host_alloc, tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }
  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    if (!DMAHelper::CanUseDMA(&parsed)) {
      return errors::Internal("GPU copy from non-DMA ",
                              DataTypeString(parsed.dtype()), " tensor");
    }
    Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
    port::Tracing::ScopedAnnotation annotation("MakeTensorFromProto");
    Notification n;
    device_contexts_[0]->CopyCPUTensorToDevice(&parsed, this, &copy,
                                               [&n, &status](const Status& s) {
                                                 status = s;
                                                 n.Notify();
                                               });
    n.WaitForNotification();
    *tensor = copy;
  }
  return status;
}

namespace {
#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
class ConcretePerOpGpuDevice : public PerOpGpuDevice {
 public:
  explicit ConcretePerOpGpuDevice(gpu::Stream* stream,
                                  Allocator* base_allocator,
                                  ::tensorflow::EventMgr* em)
      : allocator_(stream, base_allocator, em), device_(stream, &allocator_) {}

  const Eigen::GpuDevice& device() const override { return device_; }

 private:
  EigenAllocator allocator_;
  Eigen::GpuDevice device_;
};
#else
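// Gives each Op invocation its own Eigen::GpuDevice bound to the stream it
// was assigned, so Eigen expressions launch onto the correct CUDA stream.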
class ConcretePerOpGpuDevice : public PerOpGpuDevice {
 public:
  explicit ConcretePerOpGpuDevice(const cudaStream_t* cuda_stream, int gpu_id,
                                  Allocator* base_allocator)
      : stream_device_(cuda_stream, gpu_id, base_allocator),
        device_(&stream_device_) {}

  const Eigen::GpuDevice& device() const override { return device_; }

 private:
  EigenCudaStreamDevice stream_device_;
  Eigen::GpuDevice device_;
};
#endif
}  // namespace

const PerOpGpuDevice* BaseGPUDevice::NewDevice(int stream_id,
                                               Allocator* allocator) {
#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
  return new ConcretePerOpGpuDevice(streams_[stream_id], allocator, em_.get());
#else
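  // CudaStreamMemberHack() exposes the raw cudaStream_t that backs the
  // StreamExecutor stream so it can be handed directly to Eigen.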
  const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
      streams_[stream_id]->implementation()->CudaStreamMemberHack());
  return new ConcretePerOpGpuDevice(cuda_stream, gpu_id_, allocator);
#endif
}

const PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice(DeviceContext* dc,
                                                   Allocator* allocator) {
  if (dc) {
    const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc);
    const int stream_id = gpu_dc->stream_id();
    VLOG(1) << "  eigen_gpu_device(" << dc << ") => stream[" << stream_id
            << "]";
    CHECK_LT(stream_id, streams_.size());
    return NewDevice(stream_id, allocator);
  } else {
    return NewDevice(0, allocator);
  }
}

void BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
                                         const string& name_prefix,
                                         std::vector<Device*>* devices) {
  int n = INT_MAX;
  auto iter = options.config.device_count().find("GPU");
  if (iter != options.config.device_count().end()) {
    n = iter->second;
  }
  std::vector<int> valid_gpu_ids;
  GetValidDeviceIds(&valid_gpu_ids);
  if (static_cast<size_t>(n) > valid_gpu_ids.size()) {
    n = valid_gpu_ids.size();
  }
  for (int i = 0; i < n; i++) {
    devices->push_back(CreateGPUDevice(
        options, strings::StrCat(name_prefix, "/gpu:", i), valid_gpu_ids[i]));
  }
}

namespace {
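// Amount of GPU memory to reserve for the system (i.e. not hand to the
// TensorFlow GPU allocator), as a function of the memory available on the
// device.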
int64 MinSystemMemory(int64 available_memory) {
  // We use the following heuristic for now:
  //
  // If the available_memory is < 2GiB, we allocate 200MiB to system memory.
  // Otherwise, allocate the larger of 300MiB and 5% of available_memory to
  // system memory.
  //
  // In the future we could be more sophisticated by using a table of
  // devices.
  if (available_memory < (1LL << 31)) {
    // 200MiB
    return 209715200LL;
  } else {
    // max(300 MiB, 0.05 * available_memory)
    return std::max(314572800LL, static_cast<int64>(available_memory * 0.05));
  }
}
}  // namespace

static string GetShortDeviceDescription(int device_id,
                                        const gpu::DeviceDescription& desc) {
  return strings::StrCat("device: ", device_id, ", name: ", desc.name(),
                         ", pci bus id: ", desc.pci_bus_id());
}

LocalDevice* BaseGPUDeviceFactory::CreateGPUDevice(
    const SessionOptions& options, const string& name, int gpu_id) {
  CHECK_GE(gpu_id, 0);

  // Look up the device, to see its attributes.
  gpu::Platform* gpu_platform = GPUMachineManager();
  CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount());
  gpu::StreamExecutor* se =
      gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
  const gpu::DeviceDescription& desc = se->GetDeviceDescription();

  int64 total_memory, available_memory;
  CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));

  int64 allocated_memory = available_memory;
  double config_memory_fraction =
      options.config.gpu_options().per_process_gpu_memory_fraction();
  if (config_memory_fraction == 0) {
    const int64 min_system_memory = MinSystemMemory(available_memory);
    if (min_system_memory < allocated_memory) {
      allocated_memory -= min_system_memory;
    }
  } else {
    allocated_memory *= config_memory_fraction;
  }

  Bytes allocated_bytes = static_cast<Bytes>(allocated_memory);

  // Get GPU BusAdjacency from its reported NUMA affinity.
  // Because GPUs are virtualized in some environments, we can't just
  // use the GPU id.
  BusAdjacency bus_adjacency = BUS_ANY;
  switch (desc.numa_node()) {
    case 0:
      bus_adjacency = BUS_0;
      break;
    case 1:
      bus_adjacency = BUS_1;
      break;
    default:
      bus_adjacency = BUS_ANY;
  }
  VLOG(1) << "GPUDevice id " << gpu_id << " on bus " << bus_adjacency
          << " numa: " << desc.numa_node() << " pci: " << desc.pci_bus_id();

  ProcessState* process_state = ProcessState::singleton();
  return CreateGPUDevice(options, name, allocated_bytes, bus_adjacency, gpu_id,
                         GetShortDeviceDescription(gpu_id, desc),
                         process_state->GetGPUAllocator(
                             gpu_id, allocated_memory,
                             options.config.gpu_options().allocator_type()),
                         process_state->GetCPUAllocator(desc.numa_node()));
}

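// Reads TF_MIN_GPU_MULTIPROCESSOR_COUNT from the environment; returns the
// compiled-in default when the variable is unset or not a valid count.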
static int GetMinGPUMultiprocessorCount() {
  static const int kDefaultMinGPUMultiprocessorCount = 8;

  const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");

  if (tf_min_gpu_core_count == nullptr ||
      strcmp(tf_min_gpu_core_count, "") == 0) {
    return kDefaultMinGPUMultiprocessorCount;
  }

  int min_gpu_core_count = -1;
  if (strings::safe_strto32(tf_min_gpu_core_count, &min_gpu_core_count)) {
    if (min_gpu_core_count >= 0) {
      return min_gpu_core_count;
    }
  }

  LOG(ERROR) << "Invalid minimum GPU multiprocessor count: ["
             << tf_min_gpu_core_count << "]. "
             << "Using the default value: "
             << kDefaultMinGPUMultiprocessorCount;
  return kDefaultMinGPUMultiprocessorCount;
}

namespace {

struct CudaVersion {
  // Initialize from version_name in the form of "3.5"
  explicit CudaVersion(const std::string& version_name) {
    size_t dot_pos = version_name.find('.');
    CHECK(dot_pos != string::npos);
    string major_str = version_name.substr(0, dot_pos);
    CHECK(strings::safe_strto32(major_str.c_str(), &major_part));
    string minor_str = version_name.substr(dot_pos + 1);
    CHECK(strings::safe_strto32(minor_str.c_str(), &minor_part));
  }
  CudaVersion() {}
  bool operator<(const CudaVersion& other) const {
    if (this->major_part != other.major_part) {
      return this->major_part < other.major_part;
    }
    return this->minor_part < other.minor_part;
  }
  friend std::ostream& operator<<(std::ostream& os,
                                  const CudaVersion& version) {
    os << version.major_part << "." << version.minor_part;
    return os;
  }
  int major_part = -1;
  int minor_part = -1;
};

// "configure" uses the specific name to substitute the following string.
// If you change it, make sure you modify "configure" as well.
std::vector<CudaVersion> supported_cuda_compute_capabilities = {
    CudaVersion("3.5"), CudaVersion("5.2")};

}  // namespace

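// Enumerates the visible CUDA devices and returns the ids of those whose
// compute capability and (when several GPUs are present) multiprocessor
// count meet the minimum requirements above.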
void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
  auto gpu_manager = GPUMachineManager();
  int min_gpu_core_count = GetMinGPUMultiprocessorCount();
  if (gpu_manager) {
    CHECK(!supported_cuda_compute_capabilities.empty());
    CudaVersion min_supported_capability =
        *std::min_element(supported_cuda_compute_capabilities.begin(),
                          supported_cuda_compute_capabilities.end());

    auto visible_device_count = gpu_manager->VisibleDeviceCount();
    for (int i = 0; i < gpu_manager->VisibleDeviceCount(); ++i) {
      auto exec_status = gpu_manager->ExecutorForDevice(i);
      if (!exec_status.ok()) {
        continue;
      }
      gpu::StreamExecutor* se = exec_status.ValueOrDie();
      const gpu::DeviceDescription& desc = se->GetDeviceDescription();
      CudaVersion device_capability;
      if (!desc.cuda_compute_capability(&device_capability.major_part,
                                        &device_capability.minor_part)) {
        continue;
      }
      // Only GPUs with at least the minimum supported compute capability
      // are accepted.
      if (device_capability < min_supported_capability) {
        LOG(INFO) << "Ignoring gpu device "
                  << "(" << GetShortDeviceDescription(i, desc) << ") "
                  << "with Cuda compute capability " << device_capability
                  << ". The minimum required Cuda capability is "
                  << min_supported_capability << ".";
        continue;
      }

      // TensorFlow currently places computation on devices assuming
      // they have similar capability.
      //
      // If there are multiple GPUs available on the machine, only
      // consider GPUs with at least min_gpu_core_count multiprocessors
      // (8 by default).
      //
      // TODO(vrv): In the medium term: we should only filter out GPUs
      // that are slow relative to the fastest GPU. In the long term,
      // TensorFlow should support automatic placement based on
      // capability.
      if (visible_device_count > 1) {
        if (desc.core_count() < min_gpu_core_count) {
          LOG(INFO) << "Ignoring gpu device "
                    << "(" << GetShortDeviceDescription(i, desc) << ") "
                    << "with Cuda multiprocessor count: " << desc.core_count()
                    << ". The minimum required count is " << min_gpu_core_count
                    << ". You can adjust this requirement with the env var "
                       "TF_MIN_GPU_MULTIPROCESSOR_COUNT.";
          continue;
        }
      }

      int new_id = ids->size();
      ids->push_back(i);

      LOG(INFO) << "Creating TensorFlow device (/gpu:" << new_id << ") -> "
                << "(" << GetShortDeviceDescription(i, desc) << ")";
    }
  }
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA