1import signal
2import weakref
3
4from unittest2.compatibility import wraps
5
# Module-level marker; presumably unittest(2)'s result formatting looks for
# this name in a frame's globals to hide this module's frames from failure
# tracebacks -- confirm against the result class.
__unittest = True
7
8
class _InterruptHandler(object):
    """SIGINT handler that asks registered test results to stop.

    On the first SIGINT it receives while installed, it calls ``stop()`` on
    every result registered with ``registerResult``. On any later SIGINT it
    also passes the signal on to the handler that was installed before it.
    Python's default SIGINT handler raises KeyboardInterrupt.
    """

    def __init__(self, default_handler):
        # Set once a SIGINT has been handled while this object was installed.
        self.called = False
        # The SIGINT handler that was active before this one, exactly as
        # returned by signal.getsignal(). It may be a callable, or one of
        # the non-callable values SIG_DFL / SIG_IGN / None, so it is only
        # ever invoked through _delegate().
        self.default_handler = default_handler

    def _delegate(self, signum, frame):
        """Forward the signal to the previously installed handler.

        ``SIG_IGN`` is emulated by doing nothing. ``SIG_DFL`` (or ``None``)
        is emulated with ``signal.default_int_handler``, which raises
        KeyboardInterrupt. Calling these values directly would raise
        TypeError.
        """
        handler = self.default_handler
        if not callable(handler):
            if handler == signal.SIG_IGN:
                return
            handler = signal.default_int_handler
        handler(signum, frame)

    def __call__(self, signum, frame):
        """Handle SIGINT: stop registered results, or delegate."""
        installed_handler = signal.getsignal(signal.SIGINT)
        if installed_handler is not self:
            # We have been replaced (another handler chained to us, or we
            # were removed), so it is not safe to delay the interrupt:
            # delegate immediately and do not touch the results.
            self._delegate(signum, frame)
            return

        if self.called:
            # Repeated interrupt: hand it to the previous handler.
            self._delegate(signum, frame)
        self.called = True
        # Iterate over a snapshot: result.stop() runs arbitrary code, and
        # entries of the weak dictionary can vanish while we iterate.
        for result in list(_results.keys()):
            result.stop()
26
# Results that a handled SIGINT should stop. Keys are held through weak
# references, so registering a result does not keep it alive. The stored
# value is always 1.
_results = weakref.WeakKeyDictionary()


def registerResult(result):
    """Register *result* so a handled SIGINT calls its ``stop()`` method.

    Only a weak reference to *result* is kept.
    """
    registry = _results
    registry[result] = 1
30
def removeResult(result):
    """Unregister *result*.

    Return True if it had been registered, False otherwise.
    """
    previous = _results.pop(result, None)
    return True if previous else False
33
# The _InterruptHandler created by the first installHandler() call; None
# until then. removeHandler() does not reset it.
_interrupt_handler = None


def installHandler():
    """Install the SIGINT handler that stops registered results.

    Only the first call has any effect. It records the current SIGINT
    handler and installs an _InterruptHandler wrapping it. Later calls
    return immediately, including calls made after removeHandler(), because
    the module global stays set.
    """
    global _interrupt_handler
    if _interrupt_handler is not None:
        return
    previous = signal.getsignal(signal.SIGINT)
    _interrupt_handler = _InterruptHandler(previous)
    signal.signal(signal.SIGINT, _interrupt_handler)
41
42
43def removeHandler(method=None):
44    if method is not None:
45        @wraps(method)
46        def inner(*args, **kwargs):
47            initial = signal.getsignal(signal.SIGINT)
48            removeHandler()
49            try:
50                return method(*args, **kwargs)
51            finally:
52                signal.signal(signal.SIGINT, initial)
53        return inner
54
55    global _interrupt_handler
56    if _interrupt_handler is not None:
57        signal.signal(signal.SIGINT, _interrupt_handler.default_handler)
58