# pool.py revision 5bc9f4c09c99eb701dbee83fb9e26eed558e474f
#
# Module providing the `Pool` class for managing a process pool
#
# multiprocessing/pool.py
#
# Copyright (c) 2007-2008, R Oudkerk --- see COPYING.txt
#

# Public names exported by `from multiprocessing.pool import *`.  `ThreadPool`
# is a public class of this module too, so it is exported alongside `Pool`.
__all__ = ['Pool', 'ThreadPool']

#
# Imports
#

15import threading
16import Queue
17import itertools
18import collections
19import time
20
21from multiprocessing import Process, cpu_count, TimeoutError
22from multiprocessing.util import Finalize, debug

#
# Constants representing the state of a pool (also used for its handler threads)
#

# States of a pool.  The same values are stored in the `_state` attribute of
# the pool's handler threads; RUN must be 0 because those threads test
# `if thread._state:` to notice that they should stop.
RUN, CLOSE, TERMINATE = range(3)

#
# Miscellaneous
#

# Source of job ids.  Each result object takes the next id and registers
# itself in the pool's cache under it.
job_counter = itertools.count()

def mapstar(args):
    '''
    Call `map()` with the unpacked argument tuple `args`.

    `Pool` submits one ``(func, chunk)`` tuple per task so that a worker
    handles a whole chunk of a chunked map with a single call.
    '''
    # `args` is splatted straight into `map()`: (func, iterable, ...).
    return map(*args)

#
# Code run by worker processes
#

def worker(inqueue, outqueue, initializer=None, initargs=()):
    '''
    Main loop of a pool worker.

    Repeatedly takes ``(job, i, func, args, kwds)`` tasks from `inqueue`,
    evaluates ``func(*args, **kwds)`` and puts ``(job, i, (success, value))``
    on `outqueue`, where `value` is the return value when `success` is true
    and the raised exception otherwise.

    The loop ends when the sentinel ``None`` is received or when reading
    from `inqueue` fails with `EOFError`/`IOError`.

    If `initializer` is not None it is called once as
    ``initializer(*initargs)`` before any task is run.
    '''
    put = outqueue.put
    get = inqueue.get
    if hasattr(inqueue, '_writer'):
        # This side only reads tasks and writes results, so close the
        # connection ends it never uses.
        inqueue._writer.close()
        outqueue._reader.close()

    if initializer is not None:
        initializer(*initargs)

    while 1:
        try:
            task = get()
        except (EOFError, IOError):
            debug('worker got EOFError or IOError -- exiting')
            break

        if task is None:
            debug('worker got sentinel -- exiting')
            break

        job, i, func, args, kwds = task
        try:
            result = (True, func(*args, **kwds))
        except Exception as e:
            result = (False, e)

        try:
            put((job, i, result))
        except Exception as e:
            # The result could not be sent (for example because it cannot
            # be pickled).  Report that as a failure of this task instead of
            # letting the worker die, which would leave the job waiting for
            # a reply that never comes.
            try:
                reason = 'Error sending result: %r. Reason: %r' % (result[1], e)
            except Exception:
                reason = 'Error sending result. Reason: %s' % type(e).__name__
            debug('worker could not send result: %s', reason)
            put((job, i, (False, RuntimeError(reason))))

#
# Class representing a process pool
#

class Pool(object):
    '''
    Class which supports an async version of the `apply()` builtin.

    Submitted work is put on an in-process `Queue.Queue` (`_taskqueue`).  A
    task-handler thread forwards tasks to the workers through `_inqueue`,
    and a result-handler thread reads ``(job, i, result)`` triples from
    `_outqueue` and hands them to the matching result object in `_cache`.
    '''
    Process = Process

    def __init__(self, processes=None, initializer=None, initargs=()):
        '''
        Start `processes` workers (default: `cpu_count()`, or 1 if that is
        not implemented) plus the task- and result-handler threads.

        `initializer`, if not None, is called as ``initializer(*initargs)``
        in each worker when it starts.

        Raises `ValueError` if `processes` is less than 1.
        '''
        if processes is None:
            try:
                processes = cpu_count()
            except NotImplementedError:
                processes = 1
        if processes < 1:
            # A pool without workers would accept tasks that can never run:
            # `apply()` would block forever and `map_async()` would divide
            # by zero while choosing a chunksize.  Fail before creating any
            # queues, workers or threads.
            raise ValueError("Number of processes must be at least 1")

        self._setup_queues()
        self._taskqueue = Queue.Queue()
        self._cache = {}
        self._state = RUN

        self._pool = []
        for i in range(processes):
            w = self.Process(
                target=worker,
                args=(self._inqueue, self._outqueue, initializer, initargs)
                )
            self._pool.append(w)
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()

        # Feeds queued tasks to the workers.  Setting `_state` on the thread
        # object is how `_terminate_pool` later asks it to stop.
        self._task_handler = threading.Thread(
            target=Pool._handle_tasks,
            args=(self._taskqueue, self._quick_put, self._outqueue, self._pool)
            )
        self._task_handler.daemon = True
        self._task_handler._state = RUN
        self._task_handler.start()

        # Routes worker replies to the result objects stored in `_cache`.
        self._result_handler = threading.Thread(
            target=Pool._handle_results,
            args=(self._outqueue, self._quick_get, self._cache)
            )
        self._result_handler.daemon = True
        self._result_handler._state = RUN
        self._result_handler.start()

        # `terminate()` calls this finalizer; with an `exitpriority` it is
        # presumably also run at process exit (see multiprocessing.util).
        self._terminate = Finalize(
            self, self._terminate_pool,
            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
                  self._task_handler, self._result_handler, self._cache),
            exitpriority=15
            )

    def _setup_queues(self):
        '''
        Create the queues shared with the workers.

        `_quick_put`/`_quick_get` talk to the underlying connections
        directly (`_writer.send` / `_reader.recv`).
        '''
        from .queues import SimpleQueue
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def apply(self, func, args=(), kwds={}):
        '''
        Equivalent of `apply()` builtin: run ``func(*args, **kwds)`` in a
        worker and block until its result is available.
        '''
        assert self._state == RUN
        return self.apply_async(func, args, kwds).get()

    def map(self, func, iterable, chunksize=None):
        '''
        Equivalent of `map()` builtin: block until `func` has been applied
        to every item of `iterable` and return the list of results.
        '''
        assert self._state == RUN
        return self.map_async(func, iterable, chunksize).get()

    def imap(self, func, iterable, chunksize=1):
        '''
        Equivalent of `itertools.imap()` -- can be MUCH slower than
        `Pool.map()`.

        With ``chunksize == 1`` the returned `IMapIterator` also accepts a
        timeout in ``next(timeout)``; with a larger chunksize a plain
        generator over the chunks is returned instead.
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def imap_unordered(self, func, iterable, chunksize=1):
        '''
        Like `imap()` method but ordering of results is arbitrary.
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def apply_async(self, func, args=(), kwds={}, callback=None):
        '''
        Asynchronous equivalent of `apply()` builtin.

        Returns an `ApplyResult`; `callback`, if given, is called with the
        return value when the call succeeds.
        '''
        assert self._state == RUN
        result = ApplyResult(self._cache, callback)
        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
        return result

    def map_async(self, func, iterable, chunksize=None, callback=None):
        '''
        Asynchronous equivalent of `map()` builtin.

        `iterable` is materialized into a list if it has no `len()`.  When
        `chunksize` is None the work is split into roughly four chunks per
        worker.  Returns a `MapResult`; `callback`, if given, is called with
        the complete result list once every chunk has succeeded.
        '''
        assert self._state == RUN
        if not hasattr(iterable, '__len__'):
            iterable = list(iterable)

        if chunksize is None:
            chunksize, extra = divmod(len(iterable), len(self._pool) * 4)
            if extra:
                chunksize += 1

        task_batches = Pool._get_tasks(func, iterable, chunksize)
        result = MapResult(self._cache, chunksize, len(iterable), callback)
        self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                              for i, x in enumerate(task_batches)), None))
        return result

    @staticmethod
    def _handle_tasks(taskqueue, put, outqueue, pool):
        '''
        Body of the task-handler thread.

        Takes ``(taskseq, set_length)`` pairs from `taskqueue` until the
        sentinel ``None`` arrives and sends each task of `taskseq` with
        `put`.  When a sequence has been sent completely and `set_length` is
        not None, it is called with the number of tasks sent.  On exit
        (sentinel, the thread's `_state` leaving RUN, or `IOError` from
        `put`) one ``None`` sentinel is sent to the result handler through
        `outqueue` and one to each worker.
        '''
        thread = threading.current_thread()

        for taskseq, set_length in iter(taskqueue.get, None):
            i = -1
            for i, task in enumerate(taskseq):
                if thread._state:
                    debug('task handler found thread._state != RUN')
                    break
                try:
                    put(task)
                except IOError:
                    debug('could not put task on queue')
                    break
            else:
                # Whole sequence sent: report its length (i is -1 when the
                # sequence was empty) and move on to the next one.
                if set_length:
                    debug('doing set_length()')
                    set_length(i+1)
                continue
            # The inner loop was interrupted: stop handling tasks altogether.
            break
        else:
            debug('task handler got sentinel')


        try:
            # tell result handler to finish when cache is empty
            debug('task handler sending sentinel to result handler')
            outqueue.put(None)

            # tell workers there is no more work
            debug('task handler sending sentinel to workers')
            for p in pool:
                put(None)
        except IOError:
            debug('task handler got IOError when sending sentinels')

        debug('task handler exiting')

    @staticmethod
    def _handle_results(outqueue, get, cache):
        '''
        Body of the result-handler thread.

        Delivers each ``(job, i, obj)`` read with `get` to
        ``cache[job]._set(i, obj)``; jobs no longer in `cache` are ignored.
        After the first sentinel (or once the thread's `_state` is
        TERMINATE) it keeps draining results while jobs remain in `cache`,
        unless terminated, and finally discards a few leftover items from
        `outqueue`.
        '''
        thread = threading.current_thread()

        while 1:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if thread._state:
                assert thread._state == TERMINATE
                debug('result handler found thread._state=TERMINATE')
                break

            if task is None:
                debug('result handler got sentinel')
                break

            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        # Outstanding jobs may still have results in flight.
        while cache and thread._state != TERMINATE:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if task is None:
                debug('result handler ignoring extra sentinel')
                continue
            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        if hasattr(outqueue, '_reader'):
            debug('ensuring that outqueue is not full')
            # If we don't make room available in outqueue then
            # attempts to add the sentinel (None) to outqueue may
            # block.  There is guaranteed to be no more than 2 sentinels.
            try:
                for i in range(10):
                    if not outqueue._reader.poll():
                        break
                    get()
            except (IOError, EOFError):
                pass

        debug('result handler exiting: len(cache)=%s, thread._state=%s',
              len(cache), thread._state)

    @staticmethod
    def _get_tasks(func, it, size):
        '''
        Yield ``(func, chunk)`` pairs, each chunk being a tuple of up to
        `size` consecutive items of `it`.
        '''
        it = iter(it)
        while 1:
            x = tuple(itertools.islice(it, size))
            if not x:
                return
            yield (func, x)

    def __reduce__(self):
        raise NotImplementedError(
              'pool objects cannot be passed between processes or pickled'
              )

    def close(self):
        '''
        Stop accepting new work.  Already queued tasks still run; the
        workers are told to exit once the task handler reaches the sentinel.
        '''
        debug('closing pool')
        if self._state == RUN:
            self._state = CLOSE
            self._taskqueue.put(None)

    def terminate(self):
        '''
        Stop the pool right away by running the finalizer created in
        `__init__` (which terminates the workers if they support it).
        '''
        debug('terminating pool')
        self._state = TERMINATE
        self._terminate()

    def join(self):
        '''
        Wait for the handler threads and the workers to exit.  `close()` or
        `terminate()` must have been called first.
        '''
        debug('joining pool')
        assert self._state in (CLOSE, TERMINATE)
        self._task_handler.join()
        self._result_handler.join()
        for p in self._pool:
            p.join()

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        '''
        Drain `inqueue` while the task handler is alive so that a blocked
        `put` in it can complete.  `size` is not used here (the `ThreadPool`
        override uses it).
        '''
        # task_handler may be blocked trying to put items on inqueue
        debug('removing tasks from inqueue until task handler finished')
        # NOTE(review): the read lock is taken and never released, which
        # presumably keeps workers from picking up further tasks while the
        # pool is torn down.
        inqueue._rlock.acquire()
        while task_handler.is_alive() and inqueue._reader.poll():
            inqueue._reader.recv()
            time.sleep(0)

    @classmethod
    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
                        task_handler, result_handler, cache):
        '''
        Shut the pool down: stop both handler threads, terminate the
        workers if they have a `terminate()` method, and join everything.
        '''
        # this is guaranteed to only be called once
        debug('finalizing pool')

        task_handler._state = TERMINATE
        taskqueue.put(None)                 # sentinel

        debug('helping task handler/workers to finish')
        cls._help_stuff_finish(inqueue, task_handler, len(pool))

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)                  # sentinel

        if pool and hasattr(pool[0], 'terminate'):
            debug('terminating workers')
            for p in pool:
                p.terminate()

        debug('joining task handler')
        task_handler.join(1e100)

        debug('joining result handler')
        result_handler.join(1e100)

        if pool and hasattr(pool[0], 'terminate'):
            debug('joining pool workers')
            for p in pool:
                p.join()

#
# Class whose instances are returned by `Pool.apply_async()`
#

class ApplyResult(object):
    '''
    Result of a single task submitted with `Pool.apply_async()`.

    The instance registers itself in `cache` under a fresh job id; the
    pool's result handler calls `_set()` when the worker's reply arrives.
    '''

    def __init__(self, cache, callback):
        self._cond = threading.Condition(threading.Lock())
        self._job = next(job_counter)
        self._cache = cache
        self._ready = False
        self._callback = callback
        cache[self._job] = self

    def ready(self):
        '''Return whether the result has arrived.'''
        return self._ready

    def successful(self):
        '''
        Return whether the call completed without raising.  Only valid once
        `ready()` is true (checked with an assertion).
        '''
        assert self._ready
        return self._success

    def wait(self, timeout=None):
        '''Block until the result is available or `timeout` seconds pass.'''
        with self._cond:
            if not self._ready:
                self._cond.wait(timeout)

    def get(self, timeout=None):
        '''
        Return the result, waiting up to `timeout` seconds for it.

        Raises `TimeoutError` if it is not available in time, and re-raises
        the task's exception if the task failed.
        '''
        self.wait(timeout)
        if not self._ready:
            raise TimeoutError
        if self._success:
            return self._value
        else:
            raise self._value

    def _set(self, i, obj):
        '''
        Store `obj`, a ``(success, value)`` pair, run the callback on
        success, wake all waiters and drop the job from the cache.  `i` is
        unused for single results.
        '''
        self._success, self._value = obj
        if self._callback and self._success:
            self._callback(self._value)
        with self._cond:
            self._ready = True
            # Several threads may be blocked in wait()/get() on the same
            # result; notify() would wake only one of them.
            self._cond.notify_all()
        del self._cache[self._job]

#
# Class whose instances are returned by `Pool.map_async()`
#

class MapResult(ApplyResult):
    '''
    Result of `Pool.map_async()`: the results of all chunks are written into
    one list of `length` items.  The first failing chunk makes the whole
    result fail with that chunk's exception.
    '''

    def __init__(self, cache, chunksize, length, callback):
        ApplyResult.__init__(self, cache, callback)
        self._success = True
        self._value = [None] * length
        self._chunksize = chunksize
        if chunksize <= 0:
            # A chunksize of 0 (what `map_async()` computes for an empty
            # iterable) means no task is ever submitted for this job, so no
            # reply will come to remove the cache entry.  Finish now and drop
            # it; otherwise the result handler keeps waiting on a non-empty
            # cache and `Pool.join()` never returns.
            self._number_left = 0
            self._ready = True
            del self._cache[self._job]
        else:
            self._number_left = length//chunksize + bool(length % chunksize)

    def _set(self, i, success_result):
        '''
        Record the outcome of chunk number `i`; `success_result` is a
        ``(success, value)`` pair where `value` is the chunk's result list
        or the exception it raised.
        '''
        success, result = success_result
        if success:
            self._value[i*self._chunksize:(i+1)*self._chunksize] = result
            self._number_left -= 1
            if self._number_left == 0:
                if self._callback:
                    self._callback(self._value)
                del self._cache[self._job]
                self._mark_ready()
        else:
            self._success = False
            self._value = result
            del self._cache[self._job]
            self._mark_ready()

    def _mark_ready(self):
        '''Flag the result as available and wake every waiting thread.'''
        with self._cond:
            self._ready = True
            self._cond.notify_all()

#
# Class whose instances are returned by `Pool.imap()`
#

class IMapIterator(object):
    '''
    Iterator returned by `Pool.imap()`: yields results in submission order.

    Results that arrive ahead of their turn are parked in `_unsorted` until
    every earlier result has been delivered to `_items`.
    '''

    def __init__(self, cache):
        self._cond = threading.Condition(threading.Lock())
        self._job = next(job_counter)
        self._cache = cache
        self._items = collections.deque()    # results ready to hand out
        self._index = 0                      # results moved into _items so far
        self._length = None                  # total task count, once known
        self._unsorted = {}                  # early results, keyed by position
        cache[self._job] = self

    def __iter__(self):
        return self

    def next(self, timeout=None):
        '''
        Return the next result, waiting up to `timeout` seconds for it.

        Raises `StopIteration` once every result has been returned,
        `TimeoutError` if nothing arrives in time, and re-raises the task's
        exception if that task failed.
        '''
        with self._cond:
            if not self._items:
                if self._index == self._length:
                    raise StopIteration
                self._cond.wait(timeout)
                if not self._items:
                    if self._index == self._length:
                        raise StopIteration
                    raise TimeoutError
            success, value = self._items.popleft()
        if not success:
            raise value
        return value

    __next__ = next                    # Python 3 name of the iterator method

    def _set(self, i, obj):
        '''Deliver result number `i`; `obj` is a ``(success, value)`` pair.'''
        with self._cond:
            if i != self._index:
                # Arrived early: park it until its predecessors are in.
                self._unsorted[i] = obj
            else:
                self._items.append(obj)
                self._index += 1
                # Release any parked results that are now next in line.
                parked = self._unsorted
                while self._index in parked:
                    self._items.append(parked.pop(self._index))
                    self._index += 1
                self._cond.notify()

            if self._index == self._length:
                del self._cache[self._job]

    def _set_length(self, length):
        '''Record the total number of tasks once the task handler knows it.'''
        with self._cond:
            self._length = length
            if self._index == length:
                self._cond.notify()
                del self._cache[self._job]

#
# Class whose instances are returned by `Pool.imap_unordered()`
#

class IMapUnorderedIterator(IMapIterator):
    '''
    Iterator returned by `Pool.imap_unordered()`: yields results in the
    order in which they arrive rather than in submission order.
    '''

    def _set(self, i, obj):
        '''Queue `obj` for the consumer, ignoring its position `i`.'''
        with self._cond:
            self._items.append(obj)
            self._index += 1
            self._cond.notify()
            all_arrived = self._index == self._length
            if all_arrived:
                del self._cache[self._job]

#
# Variant of `Pool` built on `multiprocessing.dummy`
#

class ThreadPool(Pool):
    '''
    A `Pool` that takes its `Process` class from the sibling `.dummy`
    module and talks to its workers through plain in-process
    `Queue.Queue` objects.
    '''

    from .dummy import Process

    def __init__(self, processes=None, initializer=None, initargs=()):
        Pool.__init__(self, processes, initializer, initargs)

    def _setup_queues(self):
        '''Use ordinary `Queue.Queue` objects for both directions.'''
        inqueue, outqueue = Queue.Queue(), Queue.Queue()
        self._inqueue = inqueue
        self._outqueue = outqueue
        self._quick_put = inqueue.put
        self._quick_get = outqueue.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # Throw away whatever is still queued and leave exactly `size`
        # sentinels, so the next item every worker reads is ``None``.
        with inqueue.not_empty:
            pending = inqueue.queue
            pending.clear()
            pending.extend([None] * size)
            inqueue.not_empty.notify_all()
