1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""A utility to run functions with timeouts and retries."""
6# pylint: disable=W0702
7
8import logging
9import threading
10import time
11
12from devil.utils import reraiser_thread
13from devil.utils import watchdog_timer
14
15
class TimeoutRetryThreadGroup(reraiser_thread.ReraiserThreadGroup):
  """A ReraiserThreadGroup that carries a WatchdogTimer for its deadline.

  Args:
    timeout: the timeout, in seconds, handed to watchdog_timer.WatchdogTimer.
    threads: optional initial threads, forwarded to ReraiserThreadGroup.
  """

  def __init__(self, timeout, threads=None):
    super(TimeoutRetryThreadGroup, self).__init__(threads)
    self._watcher = watchdog_timer.WatchdogTimer(timeout)

  def GetWatcher(self):
    """Returns the WatchdogTimer created for this group in __init__."""
    return self._watcher

  def GetElapsedTime(self):
    """Returns the elapsed time as reported by the group's watchdog."""
    return self._watcher.GetElapsed()

  def GetRemainingTime(self, required=0, msg=None):
    """Returns how much time is left before this group's deadline.

    Handy as the |timeout| argument of asynchronous IO operations.

    Args:
      required: the minimum time still needed to finish the pending work,
        e.g. a sleep or an IO call.
      msg: custom error message used when the deadline cannot be met.

    Returns:
      The seconds remaining before the thread times out, or None when the
      watchdog reports no deadline.

    Raises:
      reraiser_thread.TimeoutError if fewer than |required| seconds remain.
    """
    remaining = self._watcher.GetRemaining()
    expired = remaining is not None and remaining < required
    if not expired:
      return remaining

    error = msg if msg is not None else 'Timeout expired'
    # Only explain the shortfall while some time is actually left.
    if remaining > 0:
      error += (', wait of %.1f secs required but only %.1f secs left'
                % (required, remaining))
    raise reraiser_thread.TimeoutError(error)
56
57
def CurrentTimeoutThreadGroup():
  """Finds the TimeoutRetryThreadGroup that owns or is blocked on this thread.

  Starting at the current thread group, follows the chain of
  |blocked_parent_thread_group| links until a TimeoutRetryThreadGroup is found.

  Returns:
    The first TimeoutRetryThreadGroup on that chain, or None if the chain ends
    without one.
  """
  group = reraiser_thread.CurrentThreadGroup()
  while group and not isinstance(group, TimeoutRetryThreadGroup):
    group = group.blocked_parent_thread_group
  # The loop exits either on a falsy group (report None) or on a match.
  return group if group else None
70
71
def WaitFor(condition, wait_period=5, max_tries=None):
  """Wait for a condition to become true.

  Repeatedly call the function condition(), with no arguments, until it returns
  a true value.

  If called within a TimeoutRetryThreadGroup, it cooperates nicely with it.

  Args:
    condition: function with the condition to check
    wait_period: number of seconds to wait before retrying to check the
      condition
    max_tries: maximum number of checks to make, the default tries forever
      or until the TimeoutRetryThreadGroup expires.

  Returns:
    The true value returned by the condition, or None if the condition was
    not met after max_tries.

  Raises:
    reraiser_thread.TimeoutError: if the current thread is tracked by a
      TimeoutRetryThreadGroup and less than |wait_period| seconds remain
      before another check would be made.
  """
  # functools.partial objects and callable instances have no __name__.
  condition_name = getattr(condition, '__name__', repr(condition))
  timeout_thread_group = CurrentTimeoutThreadGroup()
  while max_tries is None or max_tries > 0:
    result = condition()
    if max_tries is not None:
      max_tries -= 1
    msg = ['condition', repr(condition_name), 'met' if result else 'not met']
    if timeout_thread_group:
      # pylint: disable=no-member
      msg.append('(%.1fs)' % timeout_thread_group.GetElapsedTime())
    logging.info(' '.join(msg))
    if result:
      return result
    if max_tries is not None and max_tries <= 0:
      # That was the last allowed check: checking the deadline or sleeping
      # now would only delay (or turn into a timeout) the None result.
      return None
    if timeout_thread_group:
      # Raises TimeoutError when there is not enough time left to sleep.
      # pylint: disable=no-member
      timeout_thread_group.GetRemainingTime(wait_period,
          msg='Timed out waiting for %r' % condition_name)
    time.sleep(wait_period)
  return None
114
115
def AlwaysRetry(_exception):
  """Default retry predicate for Run: every exception is retryable.

  Args:
    _exception: the exception raised by the failed attempt (ignored).

  Returns:
    True, unconditionally.
  """
  del _exception  # Unused: the decision does not depend on the exception.
  return True
118
119
def Run(func, timeout, retries, args=None, kwargs=None, desc=None,
        error_log_func=logging.critical, retry_if_func=AlwaysRetry):
  """Runs the passed function in a separate thread with timeouts and retries.

  Every attempt runs |func| in a fresh ReraiserThread inside a new
  TimeoutRetryThreadGroup, so the |timeout| budget starts over on each try.
  At most |retries| + 1 attempts are made.

  Args:
    func: the function to be wrapped.
    timeout: the timeout in seconds for each try.
    retries: the number of retries.
    args: list of positional args to pass to |func|.
    kwargs: dictionary of keyword args to pass to |func|.
    desc: An optional description of |func| used in logging. If omitted,
      |func.__name__| will be used, or repr(func) for callables that have no
      __name__ (e.g. functools.partial objects).
    error_log_func: Logging function when logging errors.
    retry_if_func: Unary callable that takes an exception and returns
      whether |func| should be retried. Defaults to always retrying.

  Returns:
    The return value of func(*args, **kwargs).

  Raises:
    reraiser_thread.TimeoutError: if the final attempt times out, or
      |retry_if_func| declines to retry after a timeout.
    Exception: the exception raised by |func| on the final attempt, or on an
      attempt that |retry_if_func| declined to retry.
  """
  if not args:
    args = []
  if not kwargs:
    kwargs = {}
  if not desc:
    # functools.partial objects and callable instances have no __name__.
    desc = getattr(func, '__name__', repr(func))

  num_try = 1
  while True:
    thread_name = 'TimeoutThread-%d-for-%s' % (num_try,
                                               threading.current_thread().name)
    child_thread = reraiser_thread.ReraiserThread(lambda: func(*args, **kwargs),
                                                  name=thread_name)
    try:
      thread_group = TimeoutRetryThreadGroup(timeout, threads=[child_thread])
      thread_group.StartAll(will_block=True)
      while True:
        # Joining in 60s slices (presumably) lets us log progress while the
        # child is still running; the group's watcher enforces |timeout|.
        thread_group.JoinAll(watcher=thread_group.GetWatcher(), timeout=60,
                             error_log_func=error_log_func)
        if thread_group.IsAlive():
          logging.info('Still working on %s', desc)
        else:
          return thread_group.GetAllReturnValues()[0]
    except reraiser_thread.TimeoutError as e:
      # Timeouts already get their stacks logged.
      if num_try > retries or not retry_if_func(e):
        raise
    # Do not catch KeyboardInterrupt.
    except Exception as e:  # pylint: disable=broad-except
      if num_try > retries or not retry_if_func(e):
        raise
      error_log_func(
          '(%s) Exception on %s, attempt %d of %d: %r',
          thread_name, desc, num_try, retries + 1, e)
    num_try += 1
174