/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/allocator_retry.h"

#include <algorithm>
#include <chrono>
#include <memory>
#include <vector>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

class FakeAllocator {
 public:
  FakeAllocator(size_t cap, int millis_to_wait)
      : memory_capacity_(cap), millis_to_wait_(millis_to_wait) {}

  // Allocate just keeps track of the number of outstanding allocations,
  // not their sizes.  Assume a constant size for each.
  void* AllocateRaw(size_t alignment, size_t num_bytes) {
    return retry_.AllocateRaw(
        [this](size_t a, size_t nb, bool v) {
          mutex_lock l(mu_);
          if (memory_capacity_ > 0) {
            --memory_capacity_;
            return good_ptr_;
          } else {
            return static_cast<void*>(nullptr);
          }
        },
        millis_to_wait_, alignment, num_bytes);
  }

  void DeallocateRaw(void* ptr) {
    mutex_lock l(mu_);
    ++memory_capacity_;
    retry_.NotifyDealloc();
  }

 private:
  AllocatorRetry retry_;
  void* good_ptr_ = reinterpret_cast<void*>(0xdeadbeef);
  mutex mu_;
  size_t memory_capacity_ GUARDED_BY(mu_);
  int millis_to_wait_;
};

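// A minimal illustrative sketch, not part of the original test, of the
// calling convention FakeAllocator exercises above.
// AllocatorRetry::AllocateRaw takes an allocation function plus a maximum
// wait in milliseconds; a nullptr return from the function signals failure,
// after which the retry loop blocks until NotifyDealloc() fires (or the
// deadline expires) and invokes the function again.  The bool parameter
// appears to be a verbose-failure flag; like the fake above, this sketch
// ignores it.  RetrySketch is a hypothetical name, not part of the tested
// API.
inline void* RetrySketch(AllocatorRetry* retry, size_t alignment,
                         size_t num_bytes) {
  return retry->AllocateRaw(
      [](size_t a, size_t nb, bool v) {
        // A real allocation attempt would go here; returning nullptr
        // tells AllocatorRetry to wait for a deallocation and retry.
        return static_cast<void*>(nullptr);
      },
      1000 /* max wait in msec */, alignment, num_bytes);
}
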
// AllocatorRetry is a mechanism to deal with race conditions which
// are inevitable in the TensorFlow runtime where parallel Nodes can
// execute in any order.  Properly testing this feature would use real
// multi-threaded race conditions, but that leads to flaky tests as
// the expected outcome fails to occur with low but non-zero
// probability.  To make these tests reliable we simulate real race
// conditions by forcing parallel threads to take turns in the
// interesting part of their interaction with the allocator.  This
// class is the mechanism that imposes turn taking.
class AlternatingBarrier {
 public:
  explicit AlternatingBarrier(int num_users)
      : num_users_(num_users), next_turn_(0), done_(num_users, false) {}

  void WaitTurn(int user_index) {
    mutex_lock l(mu_);
    int wait_cycles = 0;
    // A user is allowed to proceed out of turn if it waits too long.
    while (next_turn_ != user_index && wait_cycles++ < 10) {
      cv_.wait_for(l, std::chrono::milliseconds(1));
    }
    if (next_turn_ == user_index) {
      IncrementTurn();
      cv_.notify_all();
    }
  }

  // When a user quits, stop reserving it a turn.
  void Done(int user_index) {
    mutex_lock l(mu_);
    done_[user_index] = true;
    if (next_turn_ == user_index) {
      IncrementTurn();
      cv_.notify_all();
    }
  }

 private:
  void IncrementTurn() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    int skipped = 0;
    while (skipped < num_users_) {
      next_turn_ = (next_turn_ + 1) % num_users_;
      if (!done_[next_turn_]) return;
      ++skipped;
    }
  }

  mutex mu_;
  condition_variable cv_;
  int num_users_;
  int next_turn_ GUARDED_BY(mu_);
  std::vector<bool> done_ GUARDED_BY(mu_);
};

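// An illustrative sketch, not part of the original tests, of the
// turn-taking protocol the fixture below follows: each of the barrier's
// num_users threads brackets every interesting allocator call with
// WaitTurn(), and calls Done() exactly once before exiting so the
// surviving threads stop waiting for its turn.  BarrierUserSketch is a
// hypothetical name.
inline void BarrierUserSketch(AlternatingBarrier* barrier, int user_index,
                              int num_steps) {
  for (int step = 0; step < num_steps; ++step) {
    barrier->WaitTurn(user_index);
    // ... one allocate or deallocate call would go here ...
  }
  barrier->Done(user_index);
}
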
class GPUAllocatorRetryTest : public ::testing::Test {
 protected:
  GPUAllocatorRetryTest() {}

  void LaunchConsumerThreads(int num_consumers, int cap_needed) {
    barrier_.reset(new AlternatingBarrier(num_consumers));
    consumer_count_.resize(num_consumers, 0);
    for (int i = 0; i < num_consumers; ++i) {
      consumers_.push_back(Env::Default()->StartThread(
          ThreadOptions(), "anon_thread", [this, i, cap_needed]() {
            do {
              void* ptr = nullptr;
              for (int j = 0; j < cap_needed; ++j) {
                barrier_->WaitTurn(i);
                ptr = alloc_->AllocateRaw(16, 1);
                if (ptr == nullptr) {
                  mutex_lock l(mu_);
                  has_failed_ = true;
                  barrier_->Done(i);
                  return;
                }
              }
              ++consumer_count_[i];
              for (int j = 0; j < cap_needed; ++j) {
                barrier_->WaitTurn(i);
                alloc_->DeallocateRaw(ptr);
              }
            } while (!notifier_.HasBeenNotified());
            barrier_->Done(i);
          }));
    }
  }

  // Wait up to wait_micros microseconds for has_failed_ to equal expected,
  // then terminate all threads.
  void JoinConsumerThreads(bool expected, int wait_micros) {
    while (wait_micros > 0) {
      {
        mutex_lock l(mu_);
        if (has_failed_ == expected) break;
      }
      int interval_micros = std::min(1000, wait_micros);
      Env::Default()->SleepForMicroseconds(interval_micros);
      wait_micros -= interval_micros;
    }
    notifier_.Notify();
    for (auto c : consumers_) {
      // Blocks until thread terminates.
      delete c;
    }
  }

  std::unique_ptr<FakeAllocator> alloc_;
  std::unique_ptr<AlternatingBarrier> barrier_;
  std::vector<Thread*> consumers_;
  std::vector<int> consumer_count_;
  Notification notifier_;
  mutex mu_;
  bool has_failed_ GUARDED_BY(mu_) = false;
  int count_ GUARDED_BY(mu_) = 0;
};

// Verifies correct retrying when memory is slightly overcommitted but
// we allow retry.
TEST_F(GPUAllocatorRetryTest, RetrySuccess) {
  // Support up to 2 allocations simultaneously, waiting up to 1000 msec
  // for a chance to alloc.
  alloc_.reset(new FakeAllocator(2, 1000));
  // Launch 3 consumers, each of whom needs 1 unit at a time.
  LaunchConsumerThreads(3, 1);
  // This should be enough time for each consumer to be satisfied many times.
  Env::Default()->SleepForMicroseconds(50000);
  JoinConsumerThreads(false, 0);
  for (int i = 0; i < 3; ++i) {
    LOG(INFO) << "Consumer " << i << " count is " << consumer_count_[i];
  }
  {
    mutex_lock l(mu_);
    EXPECT_FALSE(has_failed_);
  }
  EXPECT_GT(consumer_count_[0], 0);
  EXPECT_GT(consumer_count_[1], 0);
  EXPECT_GT(consumer_count_[2], 0);
}

// Verifies OutOfMemory failure when memory is slightly overcommitted
// and retry is not allowed.  Note that this test will fail, i.e. no
// memory alloc failure will be detected, if it is run in a context that
// does not permit real multi-threaded execution.
TEST_F(GPUAllocatorRetryTest, NoRetryFail) {
  // Support up to 2 allocations simultaneously, with no wait for a
  // chance to alloc.
  alloc_.reset(new FakeAllocator(2, 0));
  // Launch 3 consumers, each of whom needs 1 unit at a time.
  LaunchConsumerThreads(3, 1);
  Env::Default()->SleepForMicroseconds(50000);
  // Wait up to 10 seconds for the proper race condition to occur,
  // resulting in failure.
  JoinConsumerThreads(true, 10000000);
  for (int i = 0; i < 3; ++i) {
    LOG(INFO) << "Consumer " << i << " count is " << consumer_count_[i];
  }
  {
    mutex_lock l(mu_);
    EXPECT_TRUE(has_failed_);
  }
}

// Verifies OutOfMemory failure when retry is allowed but memory capacity
// is too low even for retry.
TEST_F(GPUAllocatorRetryTest, RetryInsufficientFail) {
  // Support up to 2 allocations simultaneously, waiting up to 1000 msec
  // for a chance to alloc.
  alloc_.reset(new FakeAllocator(2, 1000));
  // Launch 3 consumers, each of whom needs 2 units at a time.  We expect
  // deadlock where 2 consumers each hold 1 unit and time out trying to
  // get the second.
  LaunchConsumerThreads(3, 2);
  Env::Default()->SleepForMicroseconds(50000);
  // We're forcing a race condition, so this will fail quickly, but
  // give it 10 seconds anyway.
  JoinConsumerThreads(true, 10000000);
  for (int i = 0; i < 3; ++i) {
    LOG(INFO) << "Consumer " << i << " count is " << consumer_count_[i];
  }
  {
    mutex_lock l(mu_);
    EXPECT_TRUE(has_failed_);
  }
}

}  // namespace
}  // namespace tensorflow