1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/stream_executor/executor_cache.h"
17
18#include "tensorflow/stream_executor/lib/stringprintf.h"
19
20namespace perftools {
21namespace gputools {
22
23port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
24    const StreamExecutorConfig& config,
25    const std::function<ExecutorFactory>& factory) {
26  // In the fast path case, the cache already has an entry and we can just
27  // return after Get() which only takes a shared lock and not a unique lock.
28  // If we need to create, we take a unique lock on cache_.
29  auto fast_result = Get(config);
30  if (fast_result.ok()) {
31    return fast_result;
32  }
33
34  Entry* entry = nullptr;
35  {
36    mutex_lock lock{mutex_};
37    entry = &cache_[config.ordinal];
38    // Release the map lock; the address of 'entry' is stable because
39    // std::map guarantees reference stability.
40  }
41
42  // Acquire the per-Entry mutex without holding the map mutex. Initializing
43  // an Executor may be expensive, so we want to allow concurrent
44  // initialization of different entries.
45  mutex_lock lock{entry->configurations_mutex};
46  for (const auto& iter : entry->configurations) {
47    if (iter.first.plugin_config == config.plugin_config &&
48        iter.first.device_options == config.device_options) {
49      VLOG(2) << "hit in cache";
50      return iter.second.get();
51    }
52  }
53
54  VLOG(2) << "building executor";
55  port::StatusOr<std::unique_ptr<StreamExecutor>> result = factory();
56  if (!result.ok()) {
57    VLOG(2) << "failed to get build executor: " << result.status();
58    // If construction failed, leave the cache Entry around, but with a null
59    // executor.
60    return result.status();
61  }
62  entry->configurations.emplace_back(config, std::move(result.ValueOrDie()));
63  return entry->configurations.back().second.get();
64}
65
66port::StatusOr<StreamExecutor*> ExecutorCache::Get(
67    const StreamExecutorConfig& config) {
68  Entry* entry = nullptr;
69  {
70    tf_shared_lock lock{mutex_};
71    auto it = cache_.find(config.ordinal);
72    if (it != cache_.end()) {
73      entry = &it->second;
74    } else {
75      return port::Status(port::error::NOT_FOUND,
76                          port::Printf("No executors registered for ordinal %d",
77                                       config.ordinal));
78    }
79  }
80  tf_shared_lock lock{entry->configurations_mutex};
81  if (entry->configurations.empty()) {
82    return port::Status(
83        port::error::NOT_FOUND,
84        port::Printf("No executors registered for ordinal %d", config.ordinal));
85  }
86  for (const auto& iter : entry->configurations) {
87    if (iter.first.plugin_config == config.plugin_config &&
88        iter.first.device_options == config.device_options) {
89      VLOG(2) << "hit in cache for device ordinal " << config.ordinal;
90      return iter.second.get();
91    }
92  }
93  return port::Status(port::error::NOT_FOUND,
94                      "No executor found with a matching config.");
95}
96
97void ExecutorCache::DestroyAllExecutors() {
98  mutex_lock lock{mutex_};
99  cache_.clear();
100}
101
102ExecutorCache::Entry::~Entry() {
103  mutex_lock lock{configurations_mutex};
104  configurations.clear();
105}
106
107}  // namespace gputools
108}  // namespace perftools
109