1"""
2Test suite for SocketServer.py.
3"""
4
5import contextlib
6import imp
7import os
8import select
9import signal
10import socket
11import select
12import errno
13import tempfile
14import unittest
15import SocketServer
16
17import test.test_support
18from test.test_support import reap_children, reap_threads, verbose
19try:
20    import threading
21except ImportError:
22    threading = None
23
# Declare that this module needs the "network" test resource.  This runs at
# import time, before any of the tests below are defined.
test.test_support.requires("network")

# Line every client sends; the echo handler in make_server() returns it as-is.
TEST_STR = "hello world\n"
# Host part of the address AF_INET test servers bind to (pickaddr() pairs it
# with port 0; test_shutdown() uses it directly).
HOST = test.test_support.HOST

# Feature flags that decide which platform-specific tests get defined.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
# fork() is treated as unavailable on OS/2 even when os.fork exists.
HAVE_FORKING = hasattr(os, "fork") and os.name != "os2"
31
def signal_alarm(n):
    """Schedule SIGALRM in *n* seconds via signal.alarm (0 cancels it).

    Platforms whose signal module has no alarm() (e.g. Windows) silently
    skip the call, so callers never need to check for support themselves.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is None:
        return
    alarm(n)
36
# Remember real select() to avoid interferences with mocking
# (SocketServerTest.mocked_select_module temporarily replaces select.select).
_real_select = select.select

def receive(sock, n, timeout=20):
    """Read up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    The readiness check uses the unmocked select() captured at import time,
    so tests that patch select.select only affect the server side.

    Returns the result of a single sock.recv(n), which may be shorter than
    *n* bytes (or empty at end-of-stream).

    Raises RuntimeError if *sock* does not become readable within *timeout*.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock in readable:
        return sock.recv(n)
    # Call-style raise: the "raise E, msg" form is deprecated (PEP 8) and is
    # a syntax error on Python 3.
    raise RuntimeError("timed out on %r" % (sock,))
46
if HAVE_UNIX_SOCKETS:
    # Forking server over AF_UNIX stream sockets: ForkingMixIn combined with
    # SocketServer.UnixStreamServer.  Used by
    # SocketServerTest.test_ForkingUnixStreamServer.
    class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
                                  SocketServer.UnixStreamServer):
        pass

    # Datagram counterpart of the class above.  NOTE(review): no active test
    # in this file references it (the Unix datagram tests are commented
    # out) — confirm whether it is still needed.
    class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
                                    SocketServer.UnixDatagramServer):
        pass
55
56
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Fork a child that exits at once with status 72; reap it on exit.

    The body of the ``with`` statement runs in the parent.  When the block
    is left, the child is always waited for, even if the body raised, so no
    unreaped child is left behind.  On a normal exit, *testcase* is also used
    to assert that the reaped process is our child and exited with code 72.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Reap unconditionally: without this, an exception from the
        # with-body is thrown in at the yield and waitpid() is skipped.
        pid2, status = os.waitpid(pid, 0)
    # Only reached when the body succeeded, so these checks never mask the
    # body's own exception.
    testcase.assertEqual(pid2, pid)
    testcase.assertEqual(72 << 8, status)
67
68
69@unittest.skipUnless(threading, 'Threading required for this test.')
70class SocketServerTest(unittest.TestCase):
71    """Test all socket servers."""
72
73    def setUp(self):
74        signal_alarm(60)  # Kill deadlocks after 60 seconds.
75        self.port_seed = 0
76        self.test_files = []
77
78    def tearDown(self):
79        signal_alarm(0)  # Didn't deadlock.
80        reap_children()
81
82        for fn in self.test_files:
83            try:
84                os.remove(fn)
85            except os.error:
86                pass
87        self.test_files[:] = []
88
89    def pickaddr(self, proto):
90        if proto == socket.AF_INET:
91            return (HOST, 0)
92        else:
93            # XXX: We need a way to tell AF_UNIX to pick its own name
94            # like AF_INET provides port==0.
95            dir = None
96            if os.name == 'os2':
97                dir = '\socket'
98            fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
99            if os.name == 'os2':
100                # AF_UNIX socket names on OS/2 require a specific prefix
101                # which can't include a drive letter and must also use
102                # backslashes as directory separators
103                if fn[1] == ':':
104                    fn = fn[2:]
105                if fn[0] in (os.sep, os.altsep):
106                    fn = fn[1:]
107                if os.sep == '/':
108                    fn = fn.replace(os.sep, os.altsep)
109                else:
110                    fn = fn.replace(os.altsep, os.sep)
111            self.test_files.append(fn)
112            return fn
113
114    def make_server(self, addr, svrcls, hdlrbase):
115        class MyServer(svrcls):
116            def handle_error(self, request, client_address):
117                self.close_request(request)
118                self.server_close()
119                raise
120
121        class MyHandler(hdlrbase):
122            def handle(self):
123                line = self.rfile.readline()
124                self.wfile.write(line)
125
126        if verbose: print "creating server"
127        server = MyServer(addr, MyHandler)
128        self.assertEqual(server.server_address, server.socket.getsockname())
129        return server
130
131    @reap_threads
132    def run_server(self, svrcls, hdlrbase, testfunc):
133        server = self.make_server(self.pickaddr(svrcls.address_family),
134                                  svrcls, hdlrbase)
135        # We had the OS pick a port, so pull the real address out of
136        # the server.
137        addr = server.server_address
138        if verbose:
139            print "server created"
140            print "ADDR =", addr
141            print "CLASS =", svrcls
142        t = threading.Thread(
143            name='%s serving' % svrcls,
144            target=server.serve_forever,
145            # Short poll interval to make the test finish quickly.
146            # Time between requests is short enough that we won't wake
147            # up spuriously too many times.
148            kwargs={'poll_interval':0.01})
149        t.daemon = True  # In case this function raises.
150        t.start()
151        if verbose: print "server running"
152        for i in range(3):
153            if verbose: print "test client", i
154            testfunc(svrcls.address_family, addr)
155        if verbose: print "waiting for server"
156        server.shutdown()
157        t.join()
158        if verbose: print "done"
159
160    def stream_examine(self, proto, addr):
161        s = socket.socket(proto, socket.SOCK_STREAM)
162        s.connect(addr)
163        s.sendall(TEST_STR)
164        buf = data = receive(s, 100)
165        while data and '\n' not in buf:
166            data = receive(s, 100)
167            buf += data
168        self.assertEqual(buf, TEST_STR)
169        s.close()
170
171    def dgram_examine(self, proto, addr):
172        s = socket.socket(proto, socket.SOCK_DGRAM)
173        s.sendto(TEST_STR, addr)
174        buf = data = receive(s, 100)
175        while data and '\n' not in buf:
176            data = receive(s, 100)
177            buf += data
178        self.assertEqual(buf, TEST_STR)
179        s.close()
180
181    def test_TCPServer(self):
182        self.run_server(SocketServer.TCPServer,
183                        SocketServer.StreamRequestHandler,
184                        self.stream_examine)
185
186    def test_ThreadingTCPServer(self):
187        self.run_server(SocketServer.ThreadingTCPServer,
188                        SocketServer.StreamRequestHandler,
189                        self.stream_examine)
190
191    if HAVE_FORKING:
192        def test_ForkingTCPServer(self):
193            with simple_subprocess(self):
194                self.run_server(SocketServer.ForkingTCPServer,
195                                SocketServer.StreamRequestHandler,
196                                self.stream_examine)
197
198    if HAVE_UNIX_SOCKETS:
199        def test_UnixStreamServer(self):
200            self.run_server(SocketServer.UnixStreamServer,
201                            SocketServer.StreamRequestHandler,
202                            self.stream_examine)
203
204        def test_ThreadingUnixStreamServer(self):
205            self.run_server(SocketServer.ThreadingUnixStreamServer,
206                            SocketServer.StreamRequestHandler,
207                            self.stream_examine)
208
209        if HAVE_FORKING:
210            def test_ForkingUnixStreamServer(self):
211                with simple_subprocess(self):
212                    self.run_server(ForkingUnixStreamServer,
213                                    SocketServer.StreamRequestHandler,
214                                    self.stream_examine)
215
216    def test_UDPServer(self):
217        self.run_server(SocketServer.UDPServer,
218                        SocketServer.DatagramRequestHandler,
219                        self.dgram_examine)
220
221    def test_ThreadingUDPServer(self):
222        self.run_server(SocketServer.ThreadingUDPServer,
223                        SocketServer.DatagramRequestHandler,
224                        self.dgram_examine)
225
226    if HAVE_FORKING:
227        def test_ForkingUDPServer(self):
228            with simple_subprocess(self):
229                self.run_server(SocketServer.ForkingUDPServer,
230                                SocketServer.DatagramRequestHandler,
231                                self.dgram_examine)
232
233    @contextlib.contextmanager
234    def mocked_select_module(self):
235        """Mocks the select.select() call to raise EINTR for first call"""
236        old_select = select.select
237
238        class MockSelect:
239            def __init__(self):
240                self.called = 0
241
242            def __call__(self, *args):
243                self.called += 1
244                if self.called == 1:
245                    # raise the exception on first call
246                    raise select.error(errno.EINTR, os.strerror(errno.EINTR))
247                else:
248                    # Return real select value for consecutive calls
249                    return old_select(*args)
250
251        select.select = MockSelect()
252        try:
253            yield select.select
254        finally:
255            select.select = old_select
256
257    def test_InterruptServerSelectCall(self):
258        with self.mocked_select_module() as mock_select:
259            pid = self.run_server(SocketServer.TCPServer,
260                                  SocketServer.StreamRequestHandler,
261                                  self.stream_examine)
262            # Make sure select was called again:
263            self.assertGreater(mock_select.called, 1)
264
265    # Alas, on Linux (at least) recvfrom() doesn't return a meaningful
266    # client address so this cannot work:
267
268    # if HAVE_UNIX_SOCKETS:
269    #     def test_UnixDatagramServer(self):
270    #         self.run_server(SocketServer.UnixDatagramServer,
271    #                         SocketServer.DatagramRequestHandler,
272    #                         self.dgram_examine)
273    #
274    #     def test_ThreadingUnixDatagramServer(self):
275    #         self.run_server(SocketServer.ThreadingUnixDatagramServer,
276    #                         SocketServer.DatagramRequestHandler,
277    #                         self.dgram_examine)
278    #
279    #     if HAVE_FORKING:
280    #         def test_ForkingUnixDatagramServer(self):
281    #             self.run_server(SocketServer.ForkingUnixDatagramServer,
282    #                             SocketServer.DatagramRequestHandler,
283    #                             self.dgram_examine)
284
285    @reap_threads
286    def test_shutdown(self):
287        # Issue #2302: shutdown() should always succeed in making an
288        # other thread leave serve_forever().
289        class MyServer(SocketServer.TCPServer):
290            pass
291
292        class MyHandler(SocketServer.StreamRequestHandler):
293            pass
294
295        threads = []
296        for i in range(20):
297            s = MyServer((HOST, 0), MyHandler)
298            t = threading.Thread(
299                name='MyServer serving',
300                target=s.serve_forever,
301                kwargs={'poll_interval':0.01})
302            t.daemon = True  # In case this function raises.
303            threads.append((t, s))
304        for t, s in threads:
305            t.start()
306            s.shutdown()
307        for t, s in threads:
308            t.join()
309
310
def test_main():
    """Run SocketServerTest unless the import lock is currently held.

    The tests start server threads, which would hang if the import lock
    were held, so in that case the suite is skipped instead.
    """
    if not imp.lock_held():
        test.test_support.run_unittest(SocketServerTest)
        return
    raise unittest.SkipTest("can't run when import lock is held")
317
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    test_main()
320