1"""Unit tests for contextlib.py, and other context managers."""
2
import os
import sys
import tempfile
import unittest
from contextlib import *  # Tests __all__
from test import test_support
try:
    import threading
except ImportError:
    threading = None
12
13
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # "as", and code after the yield runs on a normal exit.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A try/finally around the yield still runs its cleanup when the
        # with-body raises, and the body's exception reaches the caller.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # When the generator lets the thrown exception propagate, calling
        # __exit__ directly must return a false value instead of raising.
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields a second time after an exception has been
        # thrown into it is misbehaving; __exit__ reports this as RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # A generator that catches the body's exception suppresses it, so the
        # with-statement itself completes without raising.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        # Return a @contextmanager-decorated function whose original function
        # carries a custom attribute (foo='bar') and a docstring, so the tests
        # below can check that the decorator copies them onto its wrapper.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        # The wrapper keeps the wrapped function's name and extra attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        # The wrapper keeps the wrapped function's docstring (skipped when the
        # interpreter strips docstrings).
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")
108
class NestedTestCase(unittest.TestCase):
    """Tests for contextlib.nested(), which combines several managers."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's __enter__ result lands in the matching tuple slot.
        def yielding(value):
            @contextmanager
            def manager():
                yield value
            return manager
        first, second, third = yielding(1), yielding(2), yielding(3)
        with nested(first(), second(), third()) as (x, y, z):
            self.assertEqual(x, 1)
            self.assertEqual(y, 2)
            self.assertEqual(z, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and exited right to left, and
        # every finally-clause still runs when the body raises.
        log = []
        def tracked(enter_mark, value, exit_mark):
            @contextmanager
            def manager():
                log.append(enter_mark)
                try:
                    yield value
                finally:
                    log.append(exit_mark)
            return manager
        outer = tracked(1, 2, 3)
        inner = tracked(4, 5, 6)
        with self.assertRaises(ZeroDivisionError):
            with nested(outer(), inner()) as (x, y):
                log.append(x)
                log.append(y)
                1 // 0
        self.assertEqual(log, [1, 4, 2, 5, 6, 3])

    def test_nested_right_exception(self):
        # An exception raised and handled inside an inner __exit__ must not
        # replace the body's exception that nested() propagates.
        @contextmanager
        def outer():
            yield 1
        class Inner(object):
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except:
                    pass
        with self.assertRaises(ZeroDivisionError):
            with nested(outer(), Inner()) as (x, y):
                1 // 0
        self.assertEqual((x, y), (1, 2))

    def test_nested_b_swallows(self):
        # When an inner manager suppresses the exception, nothing escapes.
        @contextmanager
        def passthrough():
            yield
        @contextmanager
        def swallower():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(passthrough(), swallower()):
                1 // 0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # "break" inside the with-block leaves the enclosing loop at once.
        @contextmanager
        def noop():
            yield
        counter = 0
        while True:
            counter += 1
            with nested(noop(), noop()):
                break
            counter += 10
        self.assertEqual(counter, 1)

    def test_nested_continue(self):
        # "continue" inside the with-block skips the rest of the loop body.
        @contextmanager
        def noop():
            yield
        counter = 0
        while counter < 3:
            counter += 1
            with nested(noop(), noop()):
                continue
            counter += 10
        self.assertEqual(counter, 3)

    def test_nested_return(self):
        # "return" inside the with-block returns that value from the function.
        @contextmanager
        def swallow_all():
            try:
                yield
            except:
                pass
        def run():
            with nested(swallow_all(), swallow_all()):
                return 1
            return 10
        self.assertEqual(run(), 1)
221
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing(), which calls close() on exit."""

    # XXX This needs more work

    def test_closing(self):
        # close() runs exactly once after a normal exit from the block.
        calls = []
        class Resource:
            def close(self):
                calls.append(1)
        res = Resource()
        self.assertEqual(calls, [])
        with closing(res) as bound:
            # closing() binds the wrapped object itself to the "as" target.
            self.assertEqual(res, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() still runs when the block raises, and the error propagates.
        calls = []
        class Resource:
            def close(self):
                calls.append(1)
        res = Resource()
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(res) as bound:
                self.assertEqual(res, bound)
                1 // 0
        self.assertEqual(calls, [1])
249
class FileContextTestCase(unittest.TestCase):
    """Tests that built-in file objects behave as context managers."""

    def testWithOpen(self):
        # A file opened in a with-statement is open inside the block and
        # closed afterwards, both on normal exit and when the block raises.
        #
        # mkstemp() creates the file atomically. The previously used
        # tempfile.mktemp() only returns a name and is deprecated because
        # another process could create that path first. The raw descriptor
        # is closed straight away because the test reopens the file by path.
        fd, tfn = tempfile.mkstemp()
        try:
            os.close(fd)
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 // 0
            self.assertTrue(f.closed)
        finally:
            test_support.unlink(tfn)
269
@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """Check that threading's lock-like objects work in with-statements.

    Each test passes a lock object and a zero-argument "is it held?" probe
    to boilerPlate(). That helper verifies the lock is held inside the
    with-block and released afterwards, both on normal exit and when the
    block raises.
    """

    def boilerPlate(self, lock, locked):
        """Exercise *lock* as a context manager.

        :param lock: object usable in a with-statement.
        :param locked: callable taking no arguments that returns a true
            value while *lock* is held.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        # The lock must also be released when the with-block raises.
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 // 0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_probe(sem):
        """Return a callable reporting whether *sem* cannot be acquired.

        The probe attempts a non-blocking acquire. If that succeeds, the
        slot is released again immediately and the probe reports "not
        held"; otherwise it reports "held".
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # Uses the private _is_owned() helper as the "held" probe.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        # Same probe as for RLock; the bound method is passed directly.
        self.boilerPlate(lock, lock._is_owned)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphore_probe(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphore_probe(lock))
317
# This is needed to make the test actually run under regrtest.py!
def test_main():
    """Run every TestCase in this module under a warnings filter.

    The (message, category) pair is handed to check_warnings(). The message
    text suggests it is the DeprecationWarning emitted by the deprecated
    nested() helper used in NestedTestCase (assumption, not shown here).
    """
    nested_deprecation = ("With-statements now directly support "
                          "multiple context managers",
                          DeprecationWarning)
    with test_support.check_warnings(nested_deprecation):
        test_support.run_unittest(__name__)

if __name__ == "__main__":
    # Allow running this file directly as a script.
    test_main()
327