1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "thread_pool.h"
18
19#include "base/casts.h"
20#include "base/stl_util.h"
21#include "runtime.h"
22#include "thread-inl.h"
23
24namespace art {
25
// When true, GetTask() accumulates the time workers spend blocked on task_queue_condition_ into
// total_wait_time_ (clamped to start at start_time_). Off by default so the hot path makes no
// NanoTime() calls.
static constexpr bool kMeasureWaitTime{false};
27
28ThreadPoolWorker::ThreadPoolWorker(ThreadPool* thread_pool, const std::string& name,
29                                   size_t stack_size)
30    : thread_pool_(thread_pool),
31      name_(name) {
32  std::string error_msg;
33  stack_.reset(MemMap::MapAnonymous(name.c_str(), nullptr, stack_size, PROT_READ | PROT_WRITE,
34                                    false, &error_msg));
35  CHECK(stack_.get() != nullptr) << error_msg;
36  const char* reason = "new thread pool worker thread";
37  pthread_attr_t attr;
38  CHECK_PTHREAD_CALL(pthread_attr_init, (&attr), reason);
39  CHECK_PTHREAD_CALL(pthread_attr_setstack, (&attr, stack_->Begin(), stack_->Size()), reason);
40  CHECK_PTHREAD_CALL(pthread_create, (&pthread_, &attr, &Callback, this), reason);
41  CHECK_PTHREAD_CALL(pthread_attr_destroy, (&attr), reason);
42}
43
44ThreadPoolWorker::~ThreadPoolWorker() {
45  CHECK_PTHREAD_CALL(pthread_join, (pthread_, NULL), "thread pool worker shutdown");
46}
47
48void ThreadPoolWorker::Run() {
49  Thread* self = Thread::Current();
50  Task* task = NULL;
51  thread_pool_->creation_barier_.Wait(self);
52  while ((task = thread_pool_->GetTask(self)) != NULL) {
53    task->Run(self);
54    task->Finalize();
55  }
56}
57
58void* ThreadPoolWorker::Callback(void* arg) {
59  ThreadPoolWorker* worker = reinterpret_cast<ThreadPoolWorker*>(arg);
60  Runtime* runtime = Runtime::Current();
61  CHECK(runtime->AttachCurrentThread(worker->name_.c_str(), true, NULL, false));
62  // Do work until its time to shut down.
63  worker->Run();
64  runtime->DetachCurrentThread();
65  return NULL;
66}
67
68void ThreadPool::AddTask(Thread* self, Task* task) {
69  MutexLock mu(self, task_queue_lock_);
70  tasks_.push_back(task);
71  // If we have any waiters, signal one.
72  if (started_ && waiting_count_ != 0) {
73    task_queue_condition_.Signal(self);
74  }
75}
76
77ThreadPool::ThreadPool(const char* name, size_t num_threads)
78  : name_(name),
79    task_queue_lock_("task queue lock"),
80    task_queue_condition_("task queue condition", task_queue_lock_),
81    completion_condition_("task completion condition", task_queue_lock_),
82    started_(false),
83    shutting_down_(false),
84    waiting_count_(0),
85    start_time_(0),
86    total_wait_time_(0),
87    // Add one since the caller of constructor waits on the barrier too.
88    creation_barier_(num_threads + 1),
89    max_active_workers_(num_threads) {
90  Thread* self = Thread::Current();
91  while (GetThreadCount() < num_threads) {
92    const std::string name = StringPrintf("%s worker thread %zu", name_.c_str(), GetThreadCount());
93    threads_.push_back(new ThreadPoolWorker(this, name, ThreadPoolWorker::kDefaultStackSize));
94  }
95  // Wait for all of the threads to attach.
96  creation_barier_.Wait(self);
97}
98
99void ThreadPool::SetMaxActiveWorkers(size_t threads) {
100  MutexLock mu(Thread::Current(), task_queue_lock_);
101  CHECK_LE(threads, GetThreadCount());
102  max_active_workers_ = threads;
103}
104
105ThreadPool::~ThreadPool() {
106  {
107    Thread* self = Thread::Current();
108    MutexLock mu(self, task_queue_lock_);
109    // Tell any remaining workers to shut down.
110    shutting_down_ = true;
111    // Broadcast to everyone waiting.
112    task_queue_condition_.Broadcast(self);
113    completion_condition_.Broadcast(self);
114  }
115  // Wait for the threads to finish.
116  STLDeleteElements(&threads_);
117}
118
119void ThreadPool::StartWorkers(Thread* self) {
120  MutexLock mu(self, task_queue_lock_);
121  started_ = true;
122  task_queue_condition_.Broadcast(self);
123  start_time_ = NanoTime();
124  total_wait_time_ = 0;
125}
126
127void ThreadPool::StopWorkers(Thread* self) {
128  MutexLock mu(self, task_queue_lock_);
129  started_ = false;
130}
131
132Task* ThreadPool::GetTask(Thread* self) {
133  MutexLock mu(self, task_queue_lock_);
134  while (!IsShuttingDown()) {
135    const size_t thread_count = GetThreadCount();
136    // Ensure that we don't use more threads than the maximum active workers.
137    const size_t active_threads = thread_count - waiting_count_;
138    // <= since self is considered an active worker.
139    if (active_threads <= max_active_workers_) {
140      Task* task = TryGetTaskLocked(self);
141      if (task != NULL) {
142        return task;
143      }
144    }
145
146    ++waiting_count_;
147    if (waiting_count_ == GetThreadCount() && tasks_.empty()) {
148      // We may be done, lets broadcast to the completion condition.
149      completion_condition_.Broadcast(self);
150    }
151    const uint64_t wait_start = kMeasureWaitTime ? NanoTime() : 0;
152    task_queue_condition_.Wait(self);
153    if (kMeasureWaitTime) {
154      const uint64_t wait_end = NanoTime();
155      total_wait_time_ += wait_end - std::max(wait_start, start_time_);
156    }
157    --waiting_count_;
158  }
159
160  // We are shutting down, return NULL to tell the worker thread to stop looping.
161  return NULL;
162}
163
164Task* ThreadPool::TryGetTask(Thread* self) {
165  MutexLock mu(self, task_queue_lock_);
166  return TryGetTaskLocked(self);
167}
168
169Task* ThreadPool::TryGetTaskLocked(Thread* self) {
170  if (started_ && !tasks_.empty()) {
171    Task* task = tasks_.front();
172    tasks_.pop_front();
173    return task;
174  }
175  return NULL;
176}
177
178void ThreadPool::Wait(Thread* self, bool do_work, bool may_hold_locks) {
179  if (do_work) {
180    Task* task = NULL;
181    while ((task = TryGetTask(self)) != NULL) {
182      task->Run(self);
183      task->Finalize();
184    }
185  }
186  // Wait until each thread is waiting and the task list is empty.
187  MutexLock mu(self, task_queue_lock_);
188  while (!shutting_down_ && (waiting_count_ != GetThreadCount() || !tasks_.empty())) {
189    if (!may_hold_locks) {
190      completion_condition_.Wait(self);
191    } else {
192      completion_condition_.WaitHoldingLocks(self);
193    }
194  }
195}
196
197size_t ThreadPool::GetTaskCount(Thread* self) {
198  MutexLock mu(self, task_queue_lock_);
199  return tasks_.size();
200}
201
202WorkStealingWorker::WorkStealingWorker(ThreadPool* thread_pool, const std::string& name,
203                                       size_t stack_size)
204    : ThreadPoolWorker(thread_pool, name, stack_size), task_(NULL) {}
205
206void WorkStealingWorker::Run() {
207  Thread* self = Thread::Current();
208  Task* task = NULL;
209  WorkStealingThreadPool* thread_pool = down_cast<WorkStealingThreadPool*>(thread_pool_);
210  while ((task = thread_pool_->GetTask(self)) != NULL) {
211    WorkStealingTask* stealing_task = down_cast<WorkStealingTask*>(task);
212
213    {
214      CHECK(task_ == NULL);
215      MutexLock mu(self, thread_pool->work_steal_lock_);
216      // Register that we are running the task
217      ++stealing_task->ref_count_;
218      task_ = stealing_task;
219    }
220    stealing_task->Run(self);
221    // Mark ourselves as not running a task so that nobody tries to steal from us.
222    // There is a race condition that someone starts stealing from us at this point. This is okay
223    // due to the reference counting.
224    task_ = NULL;
225
226    bool finalize;
227
228    // Steal work from tasks until there is none left to steal. Note: There is a race, but
229    // all that happens when the race occurs is that we steal some work instead of processing a
230    // task from the queue.
231    while (thread_pool->GetTaskCount(self) == 0) {
232      WorkStealingTask* steal_from_task  = NULL;
233
234      {
235        MutexLock mu(self, thread_pool->work_steal_lock_);
236        // Try finding a task to steal from.
237        steal_from_task = thread_pool->FindTaskToStealFrom(self);
238        if (steal_from_task != NULL) {
239          CHECK_NE(stealing_task, steal_from_task)
240              << "Attempting to steal from completed self task";
241          steal_from_task->ref_count_++;
242        } else {
243          break;
244        }
245      }
246
247      if (steal_from_task != NULL) {
248        // Task which completed earlier is going to steal some work.
249        stealing_task->StealFrom(self, steal_from_task);
250
251        {
252          // We are done stealing from the task, lets decrement its reference count.
253          MutexLock mu(self, thread_pool->work_steal_lock_);
254          finalize = !--steal_from_task->ref_count_;
255        }
256
257        if (finalize) {
258          steal_from_task->Finalize();
259        }
260      }
261    }
262
263    {
264      MutexLock mu(self, thread_pool->work_steal_lock_);
265      // If nobody is still referencing task_ we can finalize it.
266      finalize = !--stealing_task->ref_count_;
267    }
268
269    if (finalize) {
270      stealing_task->Finalize();
271    }
272  }
273}
274
275WorkStealingWorker::~WorkStealingWorker() {}
276
277WorkStealingThreadPool::WorkStealingThreadPool(const char* name, size_t num_threads)
278    : ThreadPool(name, 0),
279      work_steal_lock_("work stealing lock"),
280      steal_index_(0) {
281  while (GetThreadCount() < num_threads) {
282    const std::string name = StringPrintf("Work stealing worker %zu", GetThreadCount());
283    threads_.push_back(new WorkStealingWorker(this, name, ThreadPoolWorker::kDefaultStackSize));
284  }
285}
286
287WorkStealingTask* WorkStealingThreadPool::FindTaskToStealFrom(Thread* self) {
288  const size_t thread_count = GetThreadCount();
289  for (size_t i = 0; i < thread_count; ++i) {
290    // TODO: Use CAS instead of lock.
291    ++steal_index_;
292    if (steal_index_ >= thread_count) {
293      steal_index_-= thread_count;
294    }
295
296    WorkStealingWorker* worker = down_cast<WorkStealingWorker*>(threads_[steal_index_]);
297    WorkStealingTask* task = worker->task_;
298    if (task) {
299      // Not null, we can probably steal from this worker.
300      return task;
301    }
302  }
303  // Couldn't find something to steal.
304  return NULL;
305}
306
307WorkStealingThreadPool::~WorkStealingThreadPool() {}
308
309}  // namespace art
310