1# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5from multiprocessing import Queue, queues
6
7
class QueueBarrierTimeout(Exception):
    """Raised when a QueueBarrier wait does not complete within its timeout."""


class QueueBarrier(object):
    """A simple barrier that synchronizes one master process with |n| slaves.

    The barrier assumes a fixed topology: exactly one "master" and |n|
    "slaves". That asymmetry keeps the implementation simple and lets the
    participants exchange tokens while passing through the barrier: every
    slave hands one token to the master, and the master hands one token back
    to every slave.

    The master shall call master_barrier() while each slave shall call
    slave_barrier().

    As long as the same master and the same group of |n| slaves participate,
    the barrier can safely be reused several times.

    NOTE: a timeout leaves the barrier in an inconsistent state. A slave that
    times out has already posted its token to the master's queue, and a master
    that times out has already consumed (and discarded) the tokens of the
    slaves that did arrive, without releasing them. Do not reuse a barrier
    after QueueBarrierTimeout has been raised.
    """


    def __init__(self, n):
        """Initializes the barrier with |n| slave processes and a master.

        @param n: The number of slave processes.
        """
        self.n_ = n
        # Slaves -> master: each arriving slave posts its token here.
        self.queue_master_ = Queue()
        # Master -> slaves: the master posts one release token per slave.
        self.queue_slave_ = Queue()


    def master_barrier(self, token=None, timeout=None):
        """Makes the master wait until all the |n| slaves have reached this
        point, then releases them.

        @param token: A value passed to every slave.
        @param timeout: The overall time, in seconds, to wait for all the
                slaves to arrive. A None value will block forever.

        @return The list of tokens received from the slaves, in the order
                they were dequeued.
        @raise QueueBarrierTimeout: If not every slave arrived before the
                timeout expired. In that case no slave is released.
        """
        import time

        # Prefer a monotonic clock so wall-clock adjustments cannot stretch or
        # shrink the deadline; fall back on interpreters that lack one.
        clock = getattr(time, 'monotonic', time.time)
        deadline = None if timeout is None else clock() + timeout

        # Wait for all the slaves.
        result = []
        try:
            for _ in range(self.n_):
                if deadline is None:
                    remaining = None
                else:
                    # |timeout| bounds the whole wait, not each get();
                    # otherwise the master could block for n * timeout.
                    # Clamp at 0 so an expired deadline polls once instead
                    # of passing a negative timeout.
                    remaining = max(0.0, deadline - clock())
                result.append(self.queue_master_.get(timeout=remaining))
        except queues.Empty:
            # Timeout expired
            raise QueueBarrierTimeout()

        # Release all the blocked slaves.
        for _ in range(self.n_):
            self.queue_slave_.put(token)
        return result


    def slave_barrier(self, token=None, timeout=None):
        """Makes a slave wait until all the |n| slaves and the master have
        reached this point.

        @param token: A value passed to the master.
        @param timeout: The timeout, in seconds, to wait for the master to
                release this slave. A None value will block forever.

        @return The token the master passed to master_barrier().
        @raise QueueBarrierTimeout: If the master did not release this slave
                before the timeout expired.
        """
        # Announce arrival to the master first, then wait to be released.
        self.queue_master_.put(token)
        try:
            return self.queue_slave_.get(timeout=timeout)
        except queues.Empty:
            # Timeout expired
            raise QueueBarrierTimeout()
75