1# Copyright (C) 2011 Google Inc. All rights reserved.
2#
3# Redistribution and use in source and binary forms, with or without
4# modification, are permitted provided that the following conditions are
5# met:
6#
7#     * Redistributions of source code must retain the above copyright
8# notice, this list of conditions and the following disclaimer.
9#     * Redistributions in binary form must reproduce the above
10# copyright notice, this list of conditions and the following disclaimer
11# in the documentation and/or other materials provided with the
12# distribution.
13#     * Neither the name of Google Inc. nor the names of its
14# contributors may be used to endorse or promote products derived from
15# this software without specific prior written permission.
16#
17# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
23# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
25# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
29"""Module for handling messages and concurrency for run-webkit-tests
30and test-webkitpy. This module follows the design for multiprocessing.Pool
31and concurrency.futures.ProcessPoolExecutor, with the following differences:
32
33* Tasks are executed in stateful subprocesses via objects that implement the
34  Worker interface - this allows the workers to share state across tasks.
35* The pool provides an asynchronous event-handling interface so the caller
36  may receive events as tasks are processed.
37
38If you don't need these features, use multiprocessing.Pool or concurrency.futures
39intead.
40
41"""
42
43import cPickle
44import logging
45import multiprocessing
46import Queue
47import sys
48import time
49import traceback
50
51
52from webkitpy.common.host import Host
53from webkitpy.common.system import stack_utils
54
55
56_log = logging.getLogger(__name__)
57
58
def get(caller, worker_factory, num_workers, host=None):
    """Creates a message pool.

    The returned object exposes a run() method that takes a list of test
    shards and executes them in parallel, reporting events back to |caller|.
    """
    pool = _MessagePool(caller, worker_factory, num_workers, host)
    return pool
62
63
class _MessagePool(object):
    """Manages a set of _Worker processes and routes messages between them
    and the caller.

    With num_workers == 1 the single worker runs inline in this process
    (using plain Queue.Queues); otherwise workers are real
    multiprocessing.Process children communicating over multiprocessing.Queues.
    """

    def __init__(self, caller, worker_factory, num_workers, host=None):
        # caller must implement handle(name, source, *args); it receives the
        # user-level messages workers post back (see _loop).
        self._caller = caller
        self._worker_factory = worker_factory
        self._num_workers = num_workers
        self._workers = []
        self._workers_stopped = set()
        self._host = host
        self._name = 'manager'
        # A single worker runs inline: no subprocess, no pickling needed.
        self._running_inline = (self._num_workers == 1)
        if self._running_inline:
            self._messages_to_worker = Queue.Queue()
            self._messages_to_manager = Queue.Queue()
        else:
            self._messages_to_worker = multiprocessing.Queue()
            self._messages_to_manager = multiprocessing.Queue()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self._close()
        # Returning False lets any exception from the with-block propagate.
        return False

    def run(self, shards):
        """Posts a list of messages to the pool and waits for them to complete."""
        for message in shards:
            # Each shard is a tuple of (message_name, arg1, arg2, ...).
            self._messages_to_worker.put(_Message(self._name, message[0], message[1:], from_user=True, logs=()))

        # Post one 'stop' sentinel per worker so every worker loop exits.
        for _ in xrange(self._num_workers):
            self._messages_to_worker.put(_Message(self._name, 'stop', message_args=(), from_user=False, logs=()))

        self.wait()

    def _start_workers(self):
        # Creates all workers and (in the multiprocess case) spawns them.
        assert not self._workers
        self._workers_stopped = set()
        host = None
        # Only hand the host to subprocess workers when it survives pickling;
        # otherwise each worker builds its own Host (see _Worker.run).
        if self._running_inline or self._can_pickle(self._host):
            host = self._host

        for worker_number in xrange(self._num_workers):
            worker = _Worker(host, self._messages_to_manager, self._messages_to_worker, self._worker_factory, worker_number, self._running_inline, self if self._running_inline else None, self._worker_log_level())
            self._workers.append(worker)
            worker.start()

    def _worker_log_level(self):
        # Returns the lowest (most verbose) level set on any root handler,
        # so workers don't drop records the manager's handlers would emit.
        log_level = logging.NOTSET
        for handler in logging.root.handlers:
            if handler.level != logging.NOTSET:
                if log_level == logging.NOTSET:
                    log_level = handler.level
                else:
                    log_level = min(log_level, handler.level)
        return log_level

    def wait(self):
        """Starts the workers and processes messages until all workers stop."""
        try:
            self._start_workers()
            if self._running_inline:
                # The single worker runs to completion synchronously; then we
                # drain whatever it posted back without blocking.
                self._workers[0].run()
                self._loop(block=False)
            else:
                self._loop(block=True)
        finally:
            self._close()

    def _close(self):
        # Terminates any live workers and tears down the queues; safe to
        # call more than once (queues are None'd after closing).
        for worker in self._workers:
            if worker.is_alive():
                worker.terminate()
                worker.join()
        self._workers = []
        if not self._running_inline:
            # FIXME: This is a hack to get multiprocessing to not log tracebacks during shutdown :(.
            multiprocessing.util._exiting = True
            if self._messages_to_worker:
                self._messages_to_worker.close()
                self._messages_to_worker = None
            if self._messages_to_manager:
                self._messages_to_manager.close()
                self._messages_to_manager = None

    def _log_messages(self, messages):
        # Replays log records captured in a worker on this process's root logger.
        for message in messages:
            logging.root.handle(message)

    def _handle_done(self, source):
        # A worker posted its 'done' control message; once every worker has,
        # _loop stops blocking and drains the remaining messages.
        self._workers_stopped.add(source)

    @staticmethod
    def _handle_worker_exception(source, exception_type, exception_value, _):
        # Re-raises a worker-side failure in the manager process.  The
        # extracted stack (last arg) is ignored here; the worker already
        # logged the traceback before posting (see _Worker._raise).
        if exception_type == KeyboardInterrupt:
            raise exception_type(exception_value)
        raise WorkerException(str(exception_value))

    def _can_pickle(self, host):
        # True if host survives pickling and may be sent to a subprocess.
        try:
            cPickle.dumps(host)
            return True
        except TypeError:
            return False

    def _loop(self, block):
        # Dispatches messages posted by workers.  User messages go to the
        # caller; control messages dispatch to _handle_<name> methods.
        try:
            while True:
                if len(self._workers_stopped) == len(self._workers):
                    # All workers reported 'done'; drain without blocking so
                    # Queue.Empty terminates the loop.
                    block = False
                message = self._messages_to_manager.get(block)
                self._log_messages(message.logs)
                if message.from_user:
                    self._caller.handle(message.name, message.src, *message.args)
                    continue
                method = getattr(self, '_handle_' + message.name)
                assert method, 'bad message %s' % repr(message)
                method(message.src, *message.args)
        except Queue.Empty:
            pass
182
183
class WorkerException(BaseException):
    """Wraps an unexpected/unknown exception reported back by a worker."""
187
188
189class _Message(object):
190    def __init__(self, src, message_name, message_args, from_user, logs):
191        self.src = src
192        self.name = message_name
193        self.args = message_args
194        self.from_user = from_user
195        self.logs = logs
196
197    def __repr__(self):
198        return '_Message(src=%s, name=%s, args=%s, from_user=%s, logs=%s)' % (self.src, self.name, self.args, self.from_user, self.logs)
199
200
201class _Worker(multiprocessing.Process):
202    def __init__(self, host, messages_to_manager, messages_to_worker, worker_factory, worker_number, running_inline, manager, log_level):
203        super(_Worker, self).__init__()
204        self.host = host
205        self.worker_number = worker_number
206        self.name = 'worker/%d' % worker_number
207        self.log_messages = []
208        self.log_level = log_level
209        self._running = False
210        self._running_inline = running_inline
211        self._manager = manager
212
213        self._messages_to_manager = messages_to_manager
214        self._messages_to_worker = messages_to_worker
215        self._worker = worker_factory(self)
216        self._logger = None
217        self._log_handler = None
218
219    def terminate(self):
220        if self._worker:
221            if hasattr(self._worker, 'stop'):
222                self._worker.stop()
223            self._worker = None
224        if self.is_alive():
225            super(_Worker, self).terminate()
226
227    def _close(self):
228        if self._log_handler and self._logger:
229            self._logger.removeHandler(self._log_handler)
230        self._log_handler = None
231        self._logger = None
232
233    def start(self):
234        if not self._running_inline:
235            super(_Worker, self).start()
236
237    def run(self):
238        if not self.host:
239            self.host = Host()
240        if not self._running_inline:
241            self._set_up_logging()
242
243        worker = self._worker
244        exception_msg = ""
245        _log.debug("%s starting" % self.name)
246        self._running = True
247
248        try:
249            if hasattr(worker, 'start'):
250                worker.start()
251            while self._running:
252                message = self._messages_to_worker.get()
253                if message.from_user:
254                    worker.handle(message.name, message.src, *message.args)
255                    self._yield_to_manager()
256                else:
257                    assert message.name == 'stop', 'bad message %s' % repr(message)
258                    break
259
260            _log.debug("%s exiting" % self.name)
261        except Queue.Empty:
262            assert False, '%s: ran out of messages in worker queue.' % self.name
263        except KeyboardInterrupt, e:
264            self._raise(sys.exc_info())
265        except Exception, e:
266            self._raise(sys.exc_info())
267        finally:
268            try:
269                if hasattr(worker, 'stop'):
270                    worker.stop()
271            finally:
272                self._post(name='done', args=(), from_user=False)
273            self._close()
274
275    def stop_running(self):
276        self._running = False
277
278    def post(self, name, *args):
279        self._post(name, args, from_user=True)
280        self._yield_to_manager()
281
282    def _yield_to_manager(self):
283        if self._running_inline:
284            self._manager._loop(block=False)
285
286    def _post(self, name, args, from_user):
287        log_messages = self.log_messages
288        self.log_messages = []
289        self._messages_to_manager.put(_Message(self.name, name, args, from_user, log_messages))
290
291    def _raise(self, exc_info):
292        exception_type, exception_value, exception_traceback = exc_info
293        if self._running_inline:
294            raise exception_type, exception_value, exception_traceback
295
296        if exception_type == KeyboardInterrupt:
297            _log.debug("%s: interrupted, exiting" % self.name)
298            stack_utils.log_traceback(_log.debug, exception_traceback)
299        else:
300            _log.error("%s: %s('%s') raised:" % (self.name, exception_value.__class__.__name__, str(exception_value)))
301            stack_utils.log_traceback(_log.error, exception_traceback)
302        # Since tracebacks aren't picklable, send the extracted stack instead.
303        stack = traceback.extract_tb(exception_traceback)
304        self._post(name='worker_exception', args=(exception_type, exception_value, stack), from_user=False)
305
306    def _set_up_logging(self):
307        self._logger = logging.getLogger()
308
309        # The unix multiprocessing implementation clones any log handlers into the child process,
310        # so we remove them to avoid duplicate logging.
311        for h in self._logger.handlers:
312            self._logger.removeHandler(h)
313
314        self._log_handler = _WorkerLogHandler(self)
315        self._logger.addHandler(self._log_handler)
316        self._logger.setLevel(self.log_level)
317
318
319class _WorkerLogHandler(logging.Handler):
320    def __init__(self, worker):
321        logging.Handler.__init__(self)
322        self._worker = worker
323        self.setLevel(worker.log_level)
324
325    def emit(self, record):
326        self._worker.log_messages.append(record)
327