1import asyncore
2import unittest
3import select
4import os
5import socket
6import sys
7import time
8import warnings
9import errno
10import struct
11
12from test import test_support
13from test.test_support import TESTFN, run_unittest, unlink
14from StringIO import StringIO
15
# Address the test servers bind to and the test clients connect to.
HOST = test_support.HOST

try:
    import threading
except ImportError:
    # Thread support is optional: tests that need a helper thread are
    # decorated with skipUnless(threading, ...) and skip themselves.
    threading = None
22
class dummysocket:
    """Minimal socket stand-in that only remembers whether it was closed."""

    def __init__(self):
        # flipped to True by close()
        self.closed = False

    def fileno(self):
        # fixed, fake descriptor number
        return 42

    def close(self):
        self.closed = True
32
class dummychannel:
    """Fake asyncore channel that owns a dummysocket.

    close() is forwarded to the wrapped socket, so tests can verify through
    ``channel.socket.closed`` that the channel itself was closed.
    """

    def __init__(self):
        self.socket = dummysocket()

    def close(self):
        sock = self.socket
        sock.close()
39
class exitingdummy:
    """Handler whose every event callback raises asyncore.ExitNow.

    Used to check that ExitNow escapes asyncore's dispatch helpers instead
    of being handed to handle_error().
    """

    def __init__(self):
        pass

    def handle_read_event(self):
        raise asyncore.ExitNow()

    # every other callback shares the same raising implementation
    handle_write_event = handle_close = handle_expt_event = handle_read_event
50
class crashingdummy:
    """Handler whose event callbacks raise a plain Exception.

    handle_error() records that it ran, so tests can confirm that asyncore
    routes ordinary exceptions to it.
    """

    def __init__(self):
        self.error_handled = False

    def handle_error(self):
        self.error_handled = True

    def handle_read_event(self):
        raise Exception()

    # every other callback shares the same raising implementation
    handle_write_event = handle_close = handle_expt_event = handle_read_event
64
def capture_server(evt, buf, serv, timeout=3.0):
    """Accept one connection on *serv* and copy what it sends into *buf*.

    Used as a helper-thread target when testing senders.  Received data is
    written to *buf* with newline characters stripped.  Reading stops once a
    chunk containing a newline arrives, the peer closes the connection,
    200 polling rounds have elapsed, or *timeout* seconds have passed since
    the connection was accepted.

    If accept() times out on *serv*, nothing is captured.  In every case
    *serv* is closed and *evt* is set on exit so the caller can wait on it.
    """
    try:
        serv.listen(5)
        conn, addr = serv.accept()
    except socket.timeout:
        pass
    else:
        try:
            n = 200
            deadline = time.time() + timeout
            while n > 0 and time.time() < deadline:
                # Poll with a timeout so a silent peer cannot block this
                # thread (and the caller's join()) forever.
                r, w, e = select.select([conn], [], [], 0.1)
                if r:
                    data = conn.recv(10)
                    if not data:
                        # peer closed the connection without a newline
                        break
                    # keep everything except for the newline terminator
                    buf.write(data.replace(b'\n', b''))
                    if b'\n' in data:
                        break
                n -= 1
                time.sleep(0.01)
        finally:
            conn.close()
    finally:
        serv.close()
        evt.set()
89
90
class HelperFunctionTests(unittest.TestCase):
    """Tests for asyncore's module-level helper functions.

    Covers read(), write(), _exception(), readwrite(), close_all() and
    compact_traceback(), using the dummy handler classes of this module
    instead of real sockets.
    """

    def test_readwriteexc(self):
        # read(), write() and _exception() must let an ExitNow raised by a
        # handler escape unchanged ...
        helpers = (asyncore.read, asyncore.write, asyncore._exception)
        exiting = exitingdummy()
        for helper in helpers:
            self.assertRaises(asyncore.ExitNow, helper, exiting)

        # ... while any other exception is routed to the handler's
        # handle_error() method.
        for helper in helpers:
            crashing = crashingdummy()
            helper(crashing)
            self.assertEqual(crashing.error_handled, True)

    # asyncore.readwrite() relies on select-module constants that Windows
    # does not provide (see
    # http://mail.python.org/pipermail/python-list/2001-October/109973.html);
    # they are present whenever select.poll is available.
    @unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
    def test_readwrite(self):
        # readwrite() must route each poll flag to the matching handler
        attributes = ('read', 'expt', 'write', 'closed', 'error_handled')

        expected = (
            (select.POLLIN, 'read'),
            (select.POLLPRI, 'expt'),
            (select.POLLOUT, 'write'),
            (select.POLLERR, 'closed'),
            (select.POLLHUP, 'closed'),
            (select.POLLNVAL, 'closed'),
        )

        class testobj:
            # each handler flips one boolean so the test can tell which of
            # them readwrite() invoked
            def __init__(self):
                for name in attributes:
                    setattr(self, name, False)

            def handle_read_event(self):
                self.read = True

            def handle_write_event(self):
                self.write = True

            def handle_close(self):
                self.closed = True

            def handle_expt_event(self):
                self.expt = True

            def handle_error(self):
                self.error_handled = True

        for flag, wanted in expected:
            obj = testobj()
            self.assertEqual(getattr(obj, wanted), False)
            asyncore.readwrite(obj, flag)

            # only the attribute belonging to the expected handler is set
            for name in attributes:
                self.assertEqual(getattr(obj, name), name == wanted)

            # an ExitNow raised by a handler escapes readwrite() ...
            self.assertRaises(asyncore.ExitNow, asyncore.readwrite,
                              exitingdummy(), flag)

            # ... but other exceptions are passed to handle_error()
            crashing = crashingdummy()
            self.assertEqual(crashing.error_handled, False)
            asyncore.readwrite(crashing, flag)
            self.assertEqual(crashing.error_handled, True)

    def test_closeall(self):
        self.closeall_check(False)

    def test_closeall_default(self):
        self.closeall_check(True)

    def closeall_check(self, usedefault):
        """Check that close_all() closes and removes every channel in a map.

        With *usedefault* true the channels are temporarily installed as the
        global asyncore.socket_map and close_all() is called without
        arguments; otherwise the map is passed to close_all() explicitly.
        """
        channels = [dummychannel() for _ in range(10)]
        for chan in channels:
            self.assertEqual(chan.socket.closed, False)
        testmap = dict(enumerate(channels))

        if not usedefault:
            asyncore.close_all(testmap)
        else:
            saved = asyncore.socket_map
            try:
                asyncore.socket_map = testmap
                asyncore.close_all()
            finally:
                # keep whatever close_all() left behind, then restore
                testmap = asyncore.socket_map
                asyncore.socket_map = saved

        self.assertEqual(len(testmap), 0)

        for chan in channels:
            self.assertEqual(chan.socket.closed, True)

    def test_compact_traceback(self):
        try:
            raise Exception("I don't like spam!")
        except Exception:
            real_t, real_v, _ = sys.exc_info()
            r = asyncore.compact_traceback()
        else:
            self.fail("Expected exception")

        (f, function, line), t, v, info = r
        self.assertEqual(os.path.basename(f), 'test_asyncore.py')
        self.assertEqual(function, 'test_compact_traceback')
        self.assertEqual(t, real_t)
        self.assertEqual(v, real_v)
        self.assertEqual(info, '[%s|%s|%s]' % (f, function, line))
228
229
class DispatcherTests(unittest.TestCase):
    """Tests for asyncore.dispatcher objects that are not driven by a loop."""

    def tearDown(self):
        asyncore.close_all()

    def _captured(self, stream_name, *calls):
        """Return the lines written to ``sys.<stream_name>`` by *calls*.

        Each call is a ``(function, args)`` pair.  The stream is replaced by
        a StringIO while they run and is always restored afterwards.
        """
        buf = StringIO()
        saved = getattr(sys, stream_name)
        setattr(sys, stream_name, buf)
        try:
            for func, args in calls:
                func(*args)
        finally:
            setattr(sys, stream_name, saved)
        return buf.getvalue().splitlines()

    def test_basic(self):
        d = asyncore.dispatcher()
        for predicate in (d.readable, d.writable):
            self.assertEqual(predicate(), True)

    def test_repr(self):
        d = asyncore.dispatcher()
        expected = '<asyncore.dispatcher at %#x>' % id(d)
        self.assertEqual(repr(d), expected)

    def test_log(self):
        # dispatcher.log() writes to stderr
        d = asyncore.dispatcher()
        l1 = "Lovely spam! Wonderful spam!"
        l2 = "I don't like spam!"
        lines = self._captured('stderr', (d.log, (l1,)), (d.log, (l2,)))
        self.assertEqual(lines, ['log: %s' % l1, 'log: %s' % l2])

    def test_log_info(self):
        # dispatcher.log_info() prints to stdout
        d = asyncore.dispatcher()
        l1 = "Have you got anything without spam?"
        l2 = "Why can't she have egg bacon spam and sausage?"
        l3 = "THAT'S got spam in it!"
        lines = self._captured('stdout',
                               (d.log_info, (l1, 'EGGS')),
                               (d.log_info, (l2,)),
                               (d.log_info, (l3, 'SPAM')))
        expected = ['EGGS: %s' % l1, 'info: %s' % l2, 'SPAM: %s' % l3]
        self.assertEqual(lines, expected)

    def test_unhandled(self):
        d = asyncore.dispatcher()
        # an empty ignore list presumably lets every warning through
        d.ignore_log_types = ()
        handlers = (d.handle_expt, d.handle_read, d.handle_write,
                    d.handle_connect, d.handle_accept)
        # the default handlers report through log_info(), i.e. stdout
        lines = self._captured('stdout', *[(h, ()) for h in handlers])
        expected = ['warning: unhandled incoming priority event',
                    'warning: unhandled read event',
                    'warning: unhandled write event',
                    'warning: unhandled connect event',
                    'warning: unhandled accept event']
        self.assertEqual(lines, expected)

    def test_issue_8594(self):
        # XXX - this test is supposed to be removed in next major Python
        # version
        d = asyncore.dispatcher(socket.socket())
        # unknown attributes must be reported against the dispatcher
        # instance rather than the wrapped socket object
        self.assertRaisesRegexp(AttributeError, 'dispatcher instance',
                                getattr, d, 'foo')
        # attribute delegation to the underlying socket still works, but
        # must emit exactly one DeprecationWarning
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            family = d.family
            self.assertEqual(family, socket.AF_INET)
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[0].category, DeprecationWarning))

    def test_strerror(self):
        # refers to bug #8573
        err = asyncore._strerror(errno.EPERM)
        if hasattr(os, 'strerror'):
            self.assertEqual(err, os.strerror(errno.EPERM))
        # even an invalid error number must yield some text
        self.assertNotEqual(asyncore._strerror(-1), "")
335
336
class dispatcherwithsend_noread(asyncore.dispatcher_with_send):
    """dispatcher_with_send that never asks to read and ignores connect."""

    def handle_connect(self):
        pass

    def readable(self):
        return False
343
class DispatcherWithSendTests(unittest.TestCase):
    """Check that dispatcher_with_send delivers everything it buffered.

    The ``usepoll`` class attribute is passed as ``use_poll`` to
    asyncore.loop() while the output buffer is flushed; subclasses set it to
    True to exercise the poll() back end.
    """
    usepoll = False

    def tearDown(self):
        asyncore.close_all()

    @unittest.skipUnless(threading, 'Threading required for this test.')
    @test_support.reap_threads
    def test_send(self):
        evt = threading.Event()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(3)
        port = test_support.bind_port(sock)

        cap = StringIO()
        args = (evt, cap, sock)
        t = threading.Thread(target=capture_server, args=args)
        t.start()
        try:
            # wait a little longer for the server to initialize (it sometimes
            # refuses connections on slow machines without this wait)
            time.sleep(0.2)

            data = "Suppose there isn't a 16-ton weight?"
            d = dispatcherwithsend_noread()
            d.create_socket(socket.AF_INET, socket.SOCK_STREAM)
            d.connect((HOST, port))

            # give time for socket to connect
            time.sleep(0.1)

            d.send(data)
            d.send(data)
            d.send('\n')

            # Flush the output buffer one non-blocking iteration at a time.
            # Going through asyncore.loop() (rather than asyncore.poll())
            # makes the test honour usepoll; loop() itself falls back to
            # select() where select.poll is unavailable.
            n = 1000
            while d.out_buffer and n > 0:
                asyncore.loop(timeout=0, count=1, use_poll=self.usepoll)
                n -= 1

            evt.wait()

            self.assertEqual(cap.getvalue(), data*2)
        finally:
            t.join()
392
393
@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
class DispatcherWithSendTests_UsePoll(DispatcherWithSendTests):
    """Variant of DispatcherWithSendTests with ``usepoll`` set to True.

    Skipped where select.poll is missing, like TestAPI_UsePoll, so the test
    report does not claim poll coverage that cannot happen.
    """
    usepoll = True
396
@unittest.skipUnless(hasattr(asyncore, 'file_wrapper'),
                     'asyncore.file_wrapper required')
class FileWrapperTest(unittest.TestCase):
    """Tests for asyncore.file_wrapper and asyncore.file_dispatcher."""

    def setUp(self):
        self.d = "It's not dead, it's sleeping!"
        with open(TESTFN, 'w') as h:
            h.write(self.d)

    def tearDown(self):
        # test_dispatcher registers a file_dispatcher in the global
        # socket_map; close whatever is left there so no descriptor leaks
        # into later tests (and TESTFN is not held open while unlinking).
        asyncore.close_all()
        unlink(TESTFN)

    def test_recv(self):
        fd = os.open(TESTFN, os.O_RDONLY)
        w = asyncore.file_wrapper(fd)
        os.close(fd)

        # the wrapper must hold its own duplicate of the descriptor
        self.assertNotEqual(w.fd, fd)
        self.assertNotEqual(w.fileno(), fd)
        self.assertEqual(w.recv(13), "It's not dead")
        self.assertEqual(w.read(6), ", it's")
        w.close()
        self.assertRaises(OSError, w.read, 1)

    def test_send(self):
        d1 = "Come again?"
        d2 = "I want to buy some cheese."
        fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND)
        w = asyncore.file_wrapper(fd)
        os.close(fd)

        w.write(d1)
        w.send(d2)
        w.close()
        # read back through a context manager so the file is closed
        with open(TESTFN) as f:
            self.assertEqual(f.read(), self.d + d1 + d2)

    @unittest.skipUnless(hasattr(asyncore, 'file_dispatcher'),
                         'asyncore.file_dispatcher required')
    def test_dispatcher(self):
        fd = os.open(TESTFN, os.O_RDONLY)
        data = []
        class FileDispatcher(asyncore.file_dispatcher):
            def handle_read(self):
                data.append(self.recv(29))
        # the dispatcher registers itself in asyncore.socket_map; tearDown
        # closes it
        s = FileDispatcher(fd)
        os.close(fd)
        asyncore.loop(timeout=0.01, use_poll=True, count=2)
        self.assertEqual(b"".join(data), self.d)
445
446
class BaseTestHandler(asyncore.dispatcher):
    """Dispatcher that fails loudly on any event a test did not expect.

    Subclasses override only the handler they expect to fire (typically
    setting ``self.flag`` there); every other handler raises, and
    handle_error() re-raises so failures are not swallowed by asyncore.
    """

    def __init__(self, sock=None):
        asyncore.dispatcher.__init__(self, sock)
        # set to True by subclasses once the expected event has happened
        self.flag = False

    def handle_error(self):
        # re-raise the exception that is currently being handled
        raise

    def handle_accept(self):
        raise Exception("handle_accept not supposed to be called")

    def handle_connect(self):
        raise Exception("handle_connect not supposed to be called")

    def handle_expt(self):
        raise Exception("handle_expt not supposed to be called")

    def handle_close(self):
        raise Exception("handle_close not supposed to be called")
467
468
class TCPServer(asyncore.dispatcher):
    """Listening dispatcher that hands each accepted connection to a handler.

    *handler* is called with the accepted socket.  *host* and *port* form
    the bind address; the address actually bound is exposed by ``address``.
    """

    def __init__(self, handler=BaseTestHandler, host=HOST, port=0):
        asyncore.dispatcher.__init__(self)
        self.handler = handler
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind((host, port))
        self.listen(5)

    @property
    def address(self):
        """(host, port) of the listening socket."""
        host, port = self.socket.getsockname()[:2]
        return host, port

    def handle_accept(self):
        pair = self.accept()
        if pair is None:
            # accept() may return None; there is nothing to hand off then
            return
        conn = pair[0]
        self.handler(conn)

    def handle_error(self):
        # re-raise so test failures are not swallowed by asyncore
        raise
493
494
class BaseClient(BaseTestHandler):
    """BaseTestHandler that opens a TCP connection to *address* when created.

    handle_connect() is a no-op, so completing the connection is not treated
    as an unexpected event.
    """

    def __init__(self, address):
        BaseTestHandler.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.connect(address)

    def handle_connect(self):
        pass
504
505
class BaseTestAPI(unittest.TestCase):
    """End-to-end tests of the asyncore.dispatcher API over real TCP sockets.

    Not run directly: concrete subclasses define the ``use_poll`` class
    attribute, which is passed to asyncore.loop() to pick the back end.
    """

    def tearDown(self):
        asyncore.close_all()

    def loop_waiting_for_flag(self, instance, timeout=5):
        """Run the asyncore loop until ``instance.flag`` becomes true.

        Runs at most 100 single-iteration loops, sleeping ``timeout / 100``
        seconds between them, and gives up early if the socket map empties.
        Fails the test if the flag was never set.
        """
        timeout = float(timeout) / 100
        count = 100
        while asyncore.socket_map and count > 0:
            asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll)
            if instance.flag:
                return
            count -= 1
            time.sleep(timeout)
        self.fail("flag not set")

    def test_handle_connect(self):
        # make sure handle_connect is called on connect()

        class TestClient(BaseClient):
            def handle_connect(self):
                self.flag = True

        server = TCPServer()
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_handle_accept(self):
        # make sure handle_accept() is called when a client connects

        class TestListener(BaseTestHandler):

            def __init__(self):
                BaseTestHandler.__init__(self)
                self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
                self.bind((HOST, 0))
                self.listen(5)
                self.address = self.socket.getsockname()[:2]

            def handle_accept(self):
                self.flag = True

        server = TestListener()
        client = BaseClient(server.address)
        self.loop_waiting_for_flag(server)

    def test_handle_read(self):
        # make sure handle_read is called on data received

        class TestClient(BaseClient):
            def handle_read(self):
                self.flag = True

        class TestHandler(BaseTestHandler):
            def __init__(self, conn):
                BaseTestHandler.__init__(self, conn)
                # the server side sends data as soon as it accepts
                self.send('x' * 1024)

        server = TCPServer(TestHandler)
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_handle_write(self):
        # make sure handle_write is called

        class TestClient(BaseClient):
            def handle_write(self):
                self.flag = True

        server = TCPServer()
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_handle_close(self):
        # make sure handle_close is called when the other end closes
        # the connection

        class TestClient(BaseClient):

            def handle_read(self):
                # in order to make handle_close be called we are supposed
                # to make at least one recv() call
                self.recv(1024)

            def handle_close(self):
                self.flag = True
                self.close()

        class TestHandler(BaseTestHandler):
            def __init__(self, conn):
                BaseTestHandler.__init__(self, conn)
                # the server side hangs up right after accepting
                self.close()

        server = TCPServer(TestHandler)
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    @unittest.skipIf(sys.platform.startswith("sunos"),
                     "OOB support is broken on Solaris")
    def test_handle_expt(self):
        # Make sure handle_expt is called on OOB data received.
        # Note: this might fail on some platforms as OOB data is
        # tenuously supported and rarely used.

        class TestClient(BaseClient):
            def handle_expt(self):
                self.flag = True

        class TestHandler(BaseTestHandler):
            def __init__(self, conn):
                BaseTestHandler.__init__(self, conn)
                self.socket.send(chr(244), socket.MSG_OOB)

        server = TCPServer(TestHandler)
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_handle_error(self):
        # make sure an exception raised in a handler reaches handle_error()

        class TestClient(BaseClient):
            def handle_write(self):
                1.0 / 0
            def handle_error(self):
                self.flag = True
                try:
                    # re-raise the exception currently being handled
                    raise
                except ZeroDivisionError:
                    pass
                else:
                    raise Exception("exception not raised")

        server = TCPServer()
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_connection_attributes(self):
        server = TCPServer()
        client = BaseClient(server.address)

        # we start disconnected
        self.assertFalse(server.connected)
        self.assertTrue(server.accepting)
        # this can't be taken for granted across all platforms
        #self.assertFalse(client.connected)
        self.assertFalse(client.accepting)

        # execute some loops so that client connects to server
        asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100)
        self.assertFalse(server.connected)
        self.assertTrue(server.accepting)
        self.assertTrue(client.connected)
        self.assertFalse(client.accepting)

        # disconnect the client
        client.close()
        self.assertFalse(server.connected)
        self.assertTrue(server.accepting)
        self.assertFalse(client.connected)
        self.assertFalse(client.accepting)

        # stop serving
        server.close()
        self.assertFalse(server.connected)
        self.assertFalse(server.accepting)

    def test_create_socket(self):
        s = asyncore.dispatcher()
        s.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.assertEqual(s.socket.family, socket.AF_INET)
        self.assertEqual(s.socket.type, socket.SOCK_STREAM)

    def test_bind(self):
        s1 = asyncore.dispatcher()
        s1.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        s1.bind((HOST, 0))
        s1.listen(5)
        port = s1.socket.getsockname()[1]

        s2 = asyncore.dispatcher()
        s2.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        # EADDRINUSE indicates the socket was correctly bound
        self.assertRaises(socket.error, s2.bind, (HOST, port))

    def test_set_reuse_addr(self):
        sock = socket.socket()
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except socket.error:
            # unittest.skip() merely builds a decorator; skipTest() is what
            # actually marks the running test as skipped.
            self.skipTest("SO_REUSEADDR not supported on this platform")
        else:
            # if SO_REUSEADDR succeeded for sock we expect asyncore
            # to do the same
            s = asyncore.dispatcher(socket.socket())
            self.assertFalse(s.socket.getsockopt(socket.SOL_SOCKET,
                                                 socket.SO_REUSEADDR))
            # create_socket() replaces s.socket without closing the old
            # one, so close it here to avoid leaking the descriptor
            s.socket.close()
            s.create_socket(socket.AF_INET, socket.SOCK_STREAM)
            s.set_reuse_addr()
            self.assertTrue(s.socket.getsockopt(socket.SOL_SOCKET,
                                                 socket.SO_REUSEADDR))
        finally:
            sock.close()

    @unittest.skipUnless(threading, 'Threading required for this test.')
    @test_support.reap_threads
    def test_quick_connect(self):
        # see: http://bugs.python.org/issue10340
        server = TCPServer()
        t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500))
        t.start()
        self.addCleanup(t.join)

        for x in xrange(20):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.settimeout(.2)
            # linger on with a zero timeout: close() resets the connection
            s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
                         struct.pack('ii', 1, 0))
            try:
                s.connect(server.address)
            except socket.error:
                pass
            finally:
                s.close()
728
729
class TestAPI_UseSelect(BaseTestAPI):
    """Run the BaseTestAPI suite with ``use_poll`` disabled."""
    use_poll = False
732
@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
class TestAPI_UsePoll(BaseTestAPI):
    """Run the BaseTestAPI suite with ``use_poll`` enabled.

    Skipped on platforms whose select module has no poll().
    """
    use_poll = True
736
737
def test_main():
    """Run all concrete test case classes of this module via run_unittest."""
    run_unittest(
        HelperFunctionTests,
        DispatcherTests,
        DispatcherWithSendTests,
        DispatcherWithSendTests_UsePoll,
        TestAPI_UseSelect,
        TestAPI_UsePoll,
        FileWrapperTest,
    )


if __name__ == "__main__":
    test_main()
746