1#
2# Module implementing synchronization primitives
3#
4# multiprocessing/synchronize.py
5#
6# Copyright (c) 2006-2008, R Oudkerk
7# Licensed to PSF under a Contributor Agreement.
8#
9
# Public names exported by ``from multiprocessing.synchronize import *``.
__all__ = [
    'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', 'Event',
    'Barrier',
    ]
13
14import threading
15import sys
16import tempfile
17import _multiprocessing
18
19from time import time as _time
20
21from . import context
22from . import process
23from . import util
24
# Import the low-level semaphore primitives.  Platforms that lack a working
# sem_open() implementation cannot provide them, so re-raise the failure as
# an ImportError with an explanatory message (see issue 3770), chained to
# the original error for easier diagnosis.
try:
    from _multiprocessing import SemLock, sem_unlink
except ImportError as exc:
    raise ImportError("This platform lacks a functioning sem_open"
                      " implementation, therefore, the required"
                      " synchronization primitives needed will not"
                      " function, see issue 3770.") from exc
35
36#
37# Constants
38#
39
# ``kind`` codes passed to _multiprocessing.SemLock by the classes below.
RECURSIVE_MUTEX, SEMAPHORE = 0, 1
# Largest count the native semaphore supports; used as the bound of an
# unbounded Semaphore.
SEM_VALUE_MAX = _multiprocessing.SemLock.SEM_VALUE_MAX
42
43#
44# Base class for semaphores and mutexes; wraps `_multiprocessing.SemLock`
45#
46
class SemLock(object):
    """Base class for the semaphores and mutexes in this module.

    Wraps a ``_multiprocessing.SemLock`` stored in ``self._semlock`` and
    exposes its ``acquire``/``release`` methods and its context-manager
    protocol.  Instances can only be pickled while a child process is being
    spawned (see ``__getstate__``).
    """

    # Source of random suffixes used to build semaphore names.
    _rand = tempfile._RandomNameSequence()

    def __init__(self, kind, value, maxvalue, *, ctx):
        """Create the underlying native semaphore.

        *kind* is ``RECURSIVE_MUTEX`` or ``SEMAPHORE``; *value* is the
        initial count and *maxvalue* the upper bound, both passed straight
        to ``_multiprocessing.SemLock``.  If *ctx* is None the default
        context is used; its start method decides whether the semaphore's
        name is unlinked immediately.

        Raises FileExistsError if no unused name is found in 100 attempts.
        """
        if ctx is None:
            ctx = context._default_context.get_context()
        name = ctx.get_start_method()
        # On Windows, or with the 'fork' start method, the name is unlinked
        # at creation.  Otherwise it is kept -- presumably so a spawned child
        # can rebuild the semaphore from it (__getstate__ ships the name) --
        # and is unlinked by the finalizer registered below.
        unlink_now = sys.platform == 'win32' or name == 'fork'
        # Names are random; retry with a fresh one if a name is taken.
        for _ in range(100):
            try:
                sl = self._semlock = _multiprocessing.SemLock(
                    kind, value, maxvalue, self._make_name(),
                    unlink_now)
            except FileExistsError:
                pass
            else:
                break
        else:
            raise FileExistsError('cannot find name for semaphore')

        # Pass the handle as an argument so the message is only formatted
        # when debug logging is enabled.
        util.debug('created semlock with handle %s', sl.handle)
        self._make_methods()

        if sys.platform != 'win32':
            def _after_fork(obj):
                # Let the native semaphore update its state in the forked
                # child (semantics defined by _multiprocessing).
                obj._semlock._after_fork()
            util.register_after_fork(self, _after_fork)

        if self._semlock.name is not None:
            # We only get here if we are on Unix with forking
            # disabled.  When the object is garbage collected or the
            # process shuts down we unlink the semaphore name
            from .semaphore_tracker import register
            register(self._semlock.name)
            util.Finalize(self, SemLock._cleanup, (self._semlock.name,),
                          exitpriority=0)

    @staticmethod
    def _cleanup(name):
        """Unlink semaphore *name* and unregister it from the tracker."""
        from .semaphore_tracker import unregister
        sem_unlink(name)
        unregister(name)

    def _make_methods(self):
        # Bind the native methods directly on the instance.
        self.acquire = self._semlock.acquire
        self.release = self._semlock.release

    def __enter__(self):
        return self._semlock.__enter__()

    def __exit__(self, *args):
        return self._semlock.__exit__(*args)

    def __getstate__(self):
        """Return the state needed to rebuild the semaphore in a child.

        Only allowed while spawning.  On Windows the handle is duplicated
        for the child process being spawned.
        """
        context.assert_spawning(self)
        sl = self._semlock
        if sys.platform == 'win32':
            h = context.get_spawning_popen().duplicate_for_child(sl.handle)
        else:
            h = sl.handle
        return (h, sl.kind, sl.maxvalue, sl.name)

    def __setstate__(self, state):
        """Rebuild the native semaphore from ``__getstate__`` output."""
        self._semlock = _multiprocessing.SemLock._rebuild(*state)
        util.debug('recreated blocker with handle %r', state[0])
        self._make_methods()

    @staticmethod
    def _make_name():
        """Return this process's ``semprefix`` plus a random suffix."""
        return '%s-%s' % (process.current_process()._config['semprefix'],
                          next(SemLock._rand))
119
120#
121# Semaphore
122#
123
class Semaphore(SemLock):
    """Counting semaphore usable across processes.

    Backed by a ``SEMAPHORE``-kind native semaphore whose count is bounded
    only by ``SEM_VALUE_MAX``.
    """

    def __init__(self, value=1, *, ctx):
        SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX, ctx=ctx)

    def get_value(self):
        """Return the current count reported by the native semaphore."""
        return self._semlock._get_value()

    def __repr__(self):
        # Never let repr() fail: report 'unknown' if the count is unreadable.
        try:
            current = self._semlock._get_value()
        except Exception:
            current = 'unknown'
        return '<{}(value={})>'.format(self.__class__.__name__, current)
138
139#
140# Bounded semaphore
141#
142
class BoundedSemaphore(Semaphore):
    """Semaphore created with ``maxvalue`` equal to its initial *value*."""

    def __init__(self, value=1, *, ctx):
        SemLock.__init__(self, SEMAPHORE, value, value, ctx=ctx)

    def __repr__(self):
        # The current count may be unreadable; the bound is read afterwards
        # regardless.
        try:
            current = self._semlock._get_value()
        except Exception:
            current = 'unknown'
        return '<{}(value={}, maxvalue={})>'.format(
            self.__class__.__name__, current, self._semlock.maxvalue)
155
156#
157# Non-recursive lock
158#
159
class Lock(SemLock):
    """Non-recursive lock: a ``SEMAPHORE``-kind SemLock with value and
    maxvalue both 1."""

    def __init__(self, *, ctx):
        SemLock.__init__(self, SEMAPHORE, 1, 1, ctx=ctx)

    def __repr__(self):
        # Describe the holder: this process (plus thread name unless it is
        # the main thread), 'None' when unlocked, or another thread/process.
        try:
            if self._semlock._is_mine():
                owner = process.current_process().name
                thread_name = threading.current_thread().name
                if thread_name != 'MainThread':
                    owner = '{}|{}'.format(owner, thread_name)
            elif self._semlock._get_value() == 1:
                owner = 'None'
            elif self._semlock._count() > 0:
                owner = 'SomeOtherThread'
            else:
                owner = 'SomeOtherProcess'
        except Exception:
            owner = 'unknown'
        return '<{}(owner={})>'.format(self.__class__.__name__, owner)
180
181#
182# Recursive lock
183#
184
class RLock(SemLock):
    """Recursive lock: a ``RECURSIVE_MUTEX``-kind SemLock with value and
    maxvalue both 1."""

    def __init__(self, *, ctx):
        SemLock.__init__(self, RECURSIVE_MUTEX, 1, 1, ctx=ctx)

    def __repr__(self):
        # Report the holder and the recursion depth; the depth is only
        # known exactly when the caller itself holds the lock.
        try:
            if self._semlock._is_mine():
                holder = process.current_process().name
                thread_name = threading.current_thread().name
                if thread_name != 'MainThread':
                    holder = '{}|{}'.format(holder, thread_name)
                depth = self._semlock._count()
            elif self._semlock._get_value() == 1:
                holder, depth = 'None', 0
            elif self._semlock._count() > 0:
                holder, depth = 'SomeOtherThread', 'nonzero'
            else:
                holder, depth = 'SomeOtherProcess', 'nonzero'
        except Exception:
            holder, depth = 'unknown', 'unknown'
        return '<{}({}, {})>'.format(self.__class__.__name__, holder, depth)
206
207#
208# Condition variable
209#
210
class Condition(object):
    """Condition variable that can be shared between processes.

    Waiters are counted with three semaphores instead of being queued:

    * ``_sleeping_count`` is released once by every caller entering
      ``wait()``;
    * ``_woken_count`` is released once by every caller leaving ``wait()``,
      whether it was notified or timed out;
    * ``_wait_semaphore`` is released by the notifier once per waiter it
      wants to wake.

    Before waking anyone, ``notify()`` subtracts ``_woken_count`` from
    ``_sleeping_count`` to discount waiters that already left by timing out.
    """

    def __init__(self, lock=None, *, ctx):
        # Explicit None test: a supplied lock is always kept, independent
        # of its truth value.  Without one, a recursive lock is created.
        self._lock = lock if lock is not None else ctx.RLock()
        self._sleeping_count = ctx.Semaphore(0)
        self._woken_count = ctx.Semaphore(0)
        self._wait_semaphore = ctx.Semaphore(0)
        self._make_methods()

    def __getstate__(self):
        # Only picklable while spawning a child process.
        context.assert_spawning(self)
        return (self._lock, self._sleeping_count,
                self._woken_count, self._wait_semaphore)

    def __setstate__(self, state):
        (self._lock, self._sleeping_count,
         self._woken_count, self._wait_semaphore) = state
        self._make_methods()

    def __enter__(self):
        return self._lock.__enter__()

    def __exit__(self, *args):
        return self._lock.__exit__(*args)

    def _make_methods(self):
        # Expose the underlying lock's acquire/release directly.
        self.acquire = self._lock.acquire
        self.release = self._lock.release

    def __repr__(self):
        try:
            num_waiters = (self._sleeping_count._semlock._get_value() -
                           self._woken_count._semlock._get_value())
        except Exception:
            num_waiters = 'unknown'
        return '<%s(%s, %s)>' % (self.__class__.__name__, self._lock, num_waiters)

    def wait(self, timeout=None):
        """Release the lock, block until notified or *timeout* expires,
        then reacquire the lock.

        The caller must hold the lock.  If it is held recursively, it is
        fully released and then reacquired the same number of times.
        Returns the result of acquiring the internal wait semaphore:
        true if notified, false if the timeout expired.
        """
        assert self._lock._semlock._is_mine(), \
               'must acquire() condition before using wait()'

        # indicate that this thread is going to sleep
        self._sleeping_count.release()

        # release lock (every recursion level)
        count = self._lock._semlock._count()
        for _ in range(count):
            self._lock.release()

        try:
            # wait for notification or timeout
            return self._wait_semaphore.acquire(True, timeout)
        finally:
            # indicate that this thread has woken
            self._woken_count.release()

            # reacquire lock
            for _ in range(count):
                self._lock.acquire()

    def notify(self, n=1):
        """Wake up to *n* waiters (default 1).

        The caller must hold the lock.  Does not return until every waiter
        it released has left the wait on the internal semaphore.
        """
        assert self._lock._semlock._is_mine(), 'lock is not owned'
        assert not self._wait_semaphore.acquire(False)

        # to take account of timeouts since last notify*() we subtract
        # woken_count from sleeping_count and rezero woken_count
        while self._woken_count.acquire(False):
            res = self._sleeping_count.acquire(False)
            assert res

        sleepers = 0
        while sleepers < n and self._sleeping_count.acquire(False):
            self._wait_semaphore.release()        # wake up one sleeper
            sleepers += 1

        if sleepers:
            for _ in range(sleepers):
                self._woken_count.acquire()       # wait for a sleeper to wake

            # rezero wait_semaphore in case some timeouts just happened
            while self._wait_semaphore.acquire(False):
                pass

    def notify_all(self):
        """Wake every current waiter.  The caller must hold the lock."""
        self.notify(n=sys.maxsize)

    def wait_for(self, predicate, timeout=None):
        """Wait until *predicate()* is true or *timeout* seconds elapse.

        Returns the last value returned by *predicate*.  The caller must
        hold the lock.  The deadline is measured with the monotonic clock,
        so changes to the system clock cannot stretch or cut the timeout.
        """
        from time import monotonic
        result = predicate()
        if result:
            return result
        if timeout is not None:
            endtime = monotonic() + timeout
        else:
            endtime = None
            waittime = None
        while not result:
            if endtime is not None:
                waittime = endtime - monotonic()
                if waittime <= 0:
                    break
            self.wait(waittime)
            result = predicate()
        return result
328
329#
330# Event
331#
332
class Event(object):
    """Event flag shared between processes.

    The flag is a semaphore holding 0 (clear) or 1 (set).  Every operation
    runs under a condition variable built on a non-recursive lock, which is
    also used to wake waiters when the flag is set.
    """

    def __init__(self, *, ctx):
        self._cond = ctx.Condition(ctx.Lock())
        self._flag = ctx.Semaphore(0)

    def is_set(self):
        """Return True if the flag is set, leaving it unchanged."""
        with self._cond:
            if not self._flag.acquire(False):
                return False
            # Put back the token we just took.
            self._flag.release()
            return True

    def set(self):
        """Set the flag and wake all waiters."""
        with self._cond:
            # Take any existing token first so the count stays at most 1.
            self._flag.acquire(False)
            self._flag.release()
            self._cond.notify_all()

    def clear(self):
        """Clear the flag."""
        with self._cond:
            self._flag.acquire(False)

    def wait(self, timeout=None):
        """Block until the flag is set or *timeout* expires.

        Returns True if the flag is set on return, False otherwise.
        """
        with self._cond:
            already_set = self._flag.acquire(False)
            if already_set:
                self._flag.release()
            else:
                self._cond.wait(timeout)

            if not self._flag.acquire(False):
                return False
            self._flag.release()
            return True
367
368#
369# Barrier
370#
371
class Barrier(threading.Barrier):
    """Process-shareable variant of ``threading.Barrier``.

    Reuses the ``threading.Barrier`` algorithm, but its ``_state`` and
    ``_count`` attributes are properties backed by two C ints in a
    ``heap.BufferWrapper`` buffer.  Synchronization uses a ``ctx.Condition``.
    ``threading.Barrier.__init__`` is not called; the attributes are set up
    through ``__setstate__`` instead.
    """

    def __init__(self, parties, action=None, timeout=None, *, ctx):
        import struct
        from .heap import BufferWrapper
        # Room for two C ints: slot 0 is _state, slot 1 is _count.
        shared = BufferWrapper(2 * struct.calcsize('i'))
        self.__setstate__((parties, action, timeout, ctx.Condition(), shared))
        self._state = 0
        self._count = 0

    def __setstate__(self, state):
        parties, action, timeout, cond, wrapper = state
        self._parties = parties
        self._action = action
        self._timeout = timeout
        self._cond = cond
        self._wrapper = wrapper
        # Integer view over the buffer holding _state and _count.
        self._array = wrapper.create_memoryview().cast('i')

    def __getstate__(self):
        return (self._parties, self._action, self._timeout,
                self._cond, self._wrapper)

    @property
    def _state(self):
        """Barrier state, kept in slot 0 of the int buffer."""
        return self._array[0]

    @_state.setter
    def _state(self, value):
        self._array[0] = value

    @property
    def _count(self):
        """Number of waiting parties, kept in slot 1 of the int buffer."""
        return self._array[1]

    @_count.setter
    def _count(self, value):
        self._array[1] = value
407