1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Thread and ThreadGroup that reraise exceptions on the main thread."""
6# pylint: disable=W0212
7
8import logging
9import sys
10import threading
11import time
12import traceback
13
14from devil.utils import watchdog_timer
15
16
# NOTE: on Python 3 this name shadows the builtin TimeoutError within this
# module; callers should catch it via this module.
class TimeoutError(Exception):
  """Module-specific timeout exception.

  Raised by ReraiserThreadGroup joins when the supplied watchdog timer reports
  that it has timed out while threads are still running.
  """
20
21
def LogThreadStack(thread, error_log_func=logging.critical):
  """Log the stack for the given thread.

  Args:
    thread: a threading.Thread instance.
    error_log_func: Logging function when logging errors. It is called with a
      %-style format string followed by its arguments, like logging.critical.
  """
  # The thread may have exited (or never been started) between the caller's
  # liveness check and this call, in which case it has no entry in
  # sys._current_frames(). Indexing would raise KeyError and, when called from
  # an exception handler, mask the error actually being reported.
  stack = sys._current_frames().get(thread.ident)
  error_log_func('*' * 80)
  error_log_func('Stack dump for thread %r', thread.name)
  error_log_func('*' * 80)
  if stack is None:
    error_log_func('No stack available; thread %r is not running.',
                   thread.name)
  else:
    for filename, lineno, name, line in traceback.extract_stack(stack):
      error_log_func('File: "%s", line %d, in %s', filename, lineno, name)
      if line:
        error_log_func('  %s', line.strip())
  error_log_func('*' * 80)
38
39
class ReraiserThread(threading.Thread):
  """Thread class that can reraise exceptions."""

  def __init__(self, func, args=None, kwargs=None, name=None):
    """Initialize thread.

    Args:
      func: callable to call on a new thread.
      args: list of positional arguments for callable, defaults to empty.
      kwargs: dictionary of keyword arguments for callable, defaults to empty.
      name: thread name. Defaults to func.__name__ when func has a non-lambda
        name; otherwise (lambdas, functools.partial objects, other nameless
        callables) the threading default Thread-N is used.
    """
    # Not every callable has a __name__ (e.g. functools.partial objects), so
    # look it up defensively instead of raising AttributeError.
    func_name = getattr(func, '__name__', None)
    if not name and func_name and func_name != '<lambda>':
      name = func_name
    super(ReraiserThread, self).__init__(name=name)
    if not args:
      args = []
    if not kwargs:
      kwargs = {}
    self.daemon = True
    self._func = func
    self._args = args
    self._kwargs = kwargs
    self._ret = None
    # (type, value, traceback) from sys.exc_info() if func raised, else None.
    self._exc_info = None
    # The ReraiserThreadGroup this thread was added to, if any.
    self._thread_group = None

  def ReraiseIfException(self):
    """Reraise exception if an exception was raised in the thread.

    The exception captured in run() is re-raised together with its original
    traceback, so the report points at the failure on the worker thread.
    """
    if not self._exc_info:
      return
    exc_type, exc_value, exc_tb = self._exc_info
    if sys.version_info[0] >= 3:
      if exc_value is None:
        exc_value = exc_type()
      raise exc_value.with_traceback(exc_tb)
    # The three-argument raise form is a syntax error on Python 3, so it is
    # only compiled when running on Python 2 (the same technique six.reraise
    # uses).
    exec('raise exc_type, exc_value, exc_tb')  # pylint: disable=exec-used

  def GetReturnValue(self):
    """Reraise exception if present, otherwise get the return value."""
    self.ReraiseIfException()
    return self._ret

  # override
  def run(self):
    """Overrides Thread.run() to add support for reraising exceptions."""
    try:
      self._ret = self._func(*self._args, **self._kwargs)
    except:  # pylint: disable=W0702
      # Deliberately catch everything so it can be reraised on the joiner.
      self._exc_info = sys.exc_info()
84
85
class ReraiserThreadGroup(object):
  """A group of ReraiserThread objects."""

  def __init__(self, threads=None):
    """Initialize thread group.

    Args:
      threads: an iterable of ReraiserThread objects; defaults to empty.
    """
    self._threads = []
    # Set when a thread from one group has called JoinAll on another. It is used
    # to detect when there is a TimeoutRetryThread active that links to the
    # current thread.
    self.blocked_parent_thread_group = None
    if threads:
      for thread in threads:
        self.Add(thread)

  def Add(self, thread):
    """Add a thread to the group.

    Args:
      thread: a ReraiserThread object. It must not already belong to a group.
    """
    assert thread._thread_group is None
    thread._thread_group = self
    self._threads.append(thread)

  def StartAll(self, will_block=False):
    """Start all threads.

    Args:
      will_block: Whether the calling thread will subsequently block on this
        thread group. Causes the active ReraiserThreadGroup (if there is one)
        to be marked as blocking on this thread group.
    """
    if will_block:
      # Multiple threads blocking on the same outer thread should not happen in
      # practice.
      assert not self.blocked_parent_thread_group
      self.blocked_parent_thread_group = CurrentThreadGroup()
    for thread in self._threads:
      thread.start()

  def _JoinAll(self, watcher=None, timeout=None):
    """Join all threads without stack dumps.

    Reraises exceptions raised by the child threads and supports breaking
    immediately on exceptions raised on the main thread.

    Args:
      watcher: Watchdog object providing the thread timeout. If none is
          provided, the thread will never be timed out.
      timeout: An optional number of seconds to wait before timing out the join
          operation. This will not time out the threads.

    Raises:
      TimeoutError: if the watcher reports a timeout while threads are alive.
    """
    if watcher is None:
      watcher = watchdog_timer.WatchdogTimer(None)
    alive_threads = self._threads[:]
    end_time = (time.time() + timeout) if timeout else None
    try:
      while alive_threads and (end_time is None or end_time > time.time()):
        for thread in alive_threads[:]:
          if watcher.IsTimedOut():
            raise TimeoutError('Timed out waiting for %d of %d threads.' %
                               (len(alive_threads), len(self._threads)))
          # Allow the main thread to periodically check for interrupts.
          thread.join(0.1)
          if not thread.is_alive():
            alive_threads.remove(thread)
      # All threads are allowed to complete before reraising exceptions.
      for thread in self._threads:
        thread.ReraiseIfException()
    finally:
      self.blocked_parent_thread_group = None

  def IsAlive(self):
    """Check whether any of the threads are still alive.

    Returns:
      Whether any of the threads are still alive.
    """
    return any(t.is_alive() for t in self._threads)

  def JoinAll(self, watcher=None, timeout=None,
              error_log_func=logging.critical):
    """Join all threads.

    Reraises exceptions raised by the child threads and supports breaking
    immediately on exceptions raised on the main thread. Unfinished threads'
    stacks will be logged on watchdog timeout.

    Args:
      watcher: Watchdog object providing the thread timeout. If none is
          provided, the thread will never be timed out.
      timeout: An optional number of seconds to wait before timing out the join
          operation. This will not time out the threads.
      error_log_func: Logging function when logging errors.

    Raises:
      TimeoutError: if the watcher times out; stacks are logged first.
    """
    try:
      self._JoinAll(watcher, timeout)
    except TimeoutError:
      error_log_func('Timed out. Dumping threads.')
      for thread in (t for t in self._threads if t.is_alive()):
        LogThreadStack(thread, error_log_func=error_log_func)
      raise

  def GetAllReturnValues(self, watcher=None):
    """Get all return values, joining all threads if necessary.

    Args:
      watcher: same as in |JoinAll|. Only used if threads are alive.

    Returns:
      A list of each thread's return value, in the order threads were added.
      Reraises the first exception raised by any thread.
    """
    if any(t.is_alive() for t in self._threads):
      self.JoinAll(watcher)
    return [t.GetReturnValue() for t in self._threads]
202
203
def CurrentThreadGroup():
  """Look up the thread group of the calling thread.

  Returns:
    The ReraiserThreadGroup the running ReraiserThread was added to (which may
    be None if it was never added to one), or None when the running thread is
    not a ReraiserThread at all.
  """
  running = threading.current_thread()
  if not isinstance(running, ReraiserThread):
    return None
  return running._thread_group  # pylint: disable=no-member
214
215
def RunAsync(funcs, watcher=None):
  """Run each function on its own thread and collect the results.

  The calling thread's group (if any) is marked as blocked on the new group
  while waiting.

  Args:
    funcs: List of functions to perform on their own threads.
    watcher: Watchdog object providing timeout, by default waits forever.

  Returns:
    A list of return values in the order of the given functions.
  """
  workers = [ReraiserThread(fn) for fn in funcs]
  group = ReraiserThreadGroup(workers)
  group.StartAll(will_block=True)
  return group.GetAllReturnValues(watcher=watcher)
229