1"""
2Test suite for SocketServer.py.
3"""
4
5import contextlib
6import imp
7import os
8import select
9import signal
10import socket
11import select
12import errno
13import tempfile
14import unittest
15import SocketServer
16
17import test.test_support
18from test.test_support import reap_children, reap_threads, verbose
19try:
20    import threading
21except ImportError:
22    threading = None
23
# Declare that this module needs the "network" test resource.
test.test_support.requires("network")

HOST = test.test_support.HOST
TEST_STR = "hello world\n"

# Platform capabilities, and the skip decorators derived from them.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = os.name != "os2" and hasattr(os, "fork")
requires_unix_sockets = unittest.skipUnless(
    HAVE_UNIX_SOCKETS, 'requires Unix sockets')
requires_forking = unittest.skipUnless(
    HAVE_FORKING, 'requires forking')
34
def signal_alarm(n):
    """Forward *n* to signal.alarm() on platforms that provide it.

    Where the signal module has no ``alarm`` function (e.g. Windows) this
    silently does nothing.  Always returns None.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
39
# Keep a reference to the genuine select.select: some tests temporarily
# replace the module attribute, and receive() must not go through the mock.
_real_select = select.select

def receive(sock, n, timeout=20):
    """Read up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    Raises RuntimeError if *sock* does not become readable in time.
    """
    readable = _real_select([sock], [], [], timeout)[0]
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
49
if HAVE_UNIX_SOCKETS:
    # Forking AF_UNIX server classes built from the mixin; only defined
    # where the socket module exposes AF_UNIX.
    class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
                                  SocketServer.UnixStreamServer):
        """UnixStreamServer that handles each request via ForkingMixIn."""

    class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
                                    SocketServer.UnixDatagramServer):
        """UnixDatagramServer that handles each request via ForkingMixIn."""
58
59
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Fork a child that exits immediately with status 72, around a block.

    On normal exit from the ``with`` block the parent waits for that
    specific child and uses *testcase* to assert that waitpid() reported
    it and that its raw wait status is ``72 << 8``.  If something inside
    the block already reaped the child, that final waitpid() raises
    OSError.

    If the block raises, the child is still reaped (so it does not linger
    until some later cleanup), and the block's exception propagates
    unchanged: a waitpid() error on that path is ignored, and the status
    assertions are skipped.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    completed = False
    try:
        yield None
        completed = True
    finally:
        if not completed:
            # Error path: best-effort reap only.  Using finally (rather
            # than an except clause plus bare raise) keeps the original
            # exception intact even on Python 2.
            try:
                os.waitpid(pid, 0)
            except OSError:
                pass
    pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    testcase.assertEqual(72 << 8, status)
70
71
72@unittest.skipUnless(threading, 'Threading required for this test.')
73class SocketServerTest(unittest.TestCase):
74    """Test all socket servers."""
75
76    def setUp(self):
77        signal_alarm(60)  # Kill deadlocks after 60 seconds.
78        self.port_seed = 0
79        self.test_files = []
80
81    def tearDown(self):
82        signal_alarm(0)  # Didn't deadlock.
83        reap_children()
84
85        for fn in self.test_files:
86            try:
87                os.remove(fn)
88            except os.error:
89                pass
90        self.test_files[:] = []
91
92    def pickaddr(self, proto):
93        if proto == socket.AF_INET:
94            return (HOST, 0)
95        else:
96            # XXX: We need a way to tell AF_UNIX to pick its own name
97            # like AF_INET provides port==0.
98            dir = None
99            if os.name == 'os2':
100                dir = '\socket'
101            fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
102            if os.name == 'os2':
103                # AF_UNIX socket names on OS/2 require a specific prefix
104                # which can't include a drive letter and must also use
105                # backslashes as directory separators
106                if fn[1] == ':':
107                    fn = fn[2:]
108                if fn[0] in (os.sep, os.altsep):
109                    fn = fn[1:]
110                if os.sep == '/':
111                    fn = fn.replace(os.sep, os.altsep)
112                else:
113                    fn = fn.replace(os.altsep, os.sep)
114            self.test_files.append(fn)
115            return fn
116
117    def make_server(self, addr, svrcls, hdlrbase):
118        class MyServer(svrcls):
119            def handle_error(self, request, client_address):
120                self.close_request(request)
121                self.server_close()
122                raise
123
124        class MyHandler(hdlrbase):
125            def handle(self):
126                line = self.rfile.readline()
127                self.wfile.write(line)
128
129        if verbose: print "creating server"
130        server = MyServer(addr, MyHandler)
131        self.assertEqual(server.server_address, server.socket.getsockname())
132        return server
133
134    @reap_threads
135    def run_server(self, svrcls, hdlrbase, testfunc):
136        server = self.make_server(self.pickaddr(svrcls.address_family),
137                                  svrcls, hdlrbase)
138        # We had the OS pick a port, so pull the real address out of
139        # the server.
140        addr = server.server_address
141        if verbose:
142            print "server created"
143            print "ADDR =", addr
144            print "CLASS =", svrcls
145        t = threading.Thread(
146            name='%s serving' % svrcls,
147            target=server.serve_forever,
148            # Short poll interval to make the test finish quickly.
149            # Time between requests is short enough that we won't wake
150            # up spuriously too many times.
151            kwargs={'poll_interval':0.01})
152        t.daemon = True  # In case this function raises.
153        t.start()
154        if verbose: print "server running"
155        for i in range(3):
156            if verbose: print "test client", i
157            testfunc(svrcls.address_family, addr)
158        if verbose: print "waiting for server"
159        server.shutdown()
160        t.join()
161        server.server_close()
162        self.assertRaises(socket.error, server.socket.fileno)
163        if verbose: print "done"
164
165    def stream_examine(self, proto, addr):
166        s = socket.socket(proto, socket.SOCK_STREAM)
167        s.connect(addr)
168        s.sendall(TEST_STR)
169        buf = data = receive(s, 100)
170        while data and '\n' not in buf:
171            data = receive(s, 100)
172            buf += data
173        self.assertEqual(buf, TEST_STR)
174        s.close()
175
176    def dgram_examine(self, proto, addr):
177        s = socket.socket(proto, socket.SOCK_DGRAM)
178        if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
179            s.bind(self.pickaddr(proto))
180        s.sendto(TEST_STR, addr)
181        buf = data = receive(s, 100)
182        while data and '\n' not in buf:
183            data = receive(s, 100)
184            buf += data
185        self.assertEqual(buf, TEST_STR)
186        s.close()
187
188    def test_TCPServer(self):
189        self.run_server(SocketServer.TCPServer,
190                        SocketServer.StreamRequestHandler,
191                        self.stream_examine)
192
193    def test_ThreadingTCPServer(self):
194        self.run_server(SocketServer.ThreadingTCPServer,
195                        SocketServer.StreamRequestHandler,
196                        self.stream_examine)
197
198    @requires_forking
199    def test_ForkingTCPServer(self):
200        with simple_subprocess(self):
201            self.run_server(SocketServer.ForkingTCPServer,
202                            SocketServer.StreamRequestHandler,
203                            self.stream_examine)
204
205    @requires_unix_sockets
206    def test_UnixStreamServer(self):
207        self.run_server(SocketServer.UnixStreamServer,
208                        SocketServer.StreamRequestHandler,
209                        self.stream_examine)
210
211    @requires_unix_sockets
212    def test_ThreadingUnixStreamServer(self):
213        self.run_server(SocketServer.ThreadingUnixStreamServer,
214                        SocketServer.StreamRequestHandler,
215                        self.stream_examine)
216
217    @requires_unix_sockets
218    @requires_forking
219    def test_ForkingUnixStreamServer(self):
220        with simple_subprocess(self):
221            self.run_server(ForkingUnixStreamServer,
222                            SocketServer.StreamRequestHandler,
223                            self.stream_examine)
224
225    def test_UDPServer(self):
226        self.run_server(SocketServer.UDPServer,
227                        SocketServer.DatagramRequestHandler,
228                        self.dgram_examine)
229
230    def test_ThreadingUDPServer(self):
231        self.run_server(SocketServer.ThreadingUDPServer,
232                        SocketServer.DatagramRequestHandler,
233                        self.dgram_examine)
234
235    @requires_forking
236    def test_ForkingUDPServer(self):
237        with simple_subprocess(self):
238            self.run_server(SocketServer.ForkingUDPServer,
239                            SocketServer.DatagramRequestHandler,
240                            self.dgram_examine)
241
242    @contextlib.contextmanager
243    def mocked_select_module(self):
244        """Mocks the select.select() call to raise EINTR for first call"""
245        old_select = select.select
246
247        class MockSelect:
248            def __init__(self):
249                self.called = 0
250
251            def __call__(self, *args):
252                self.called += 1
253                if self.called == 1:
254                    # raise the exception on first call
255                    raise select.error(errno.EINTR, os.strerror(errno.EINTR))
256                else:
257                    # Return real select value for consecutive calls
258                    return old_select(*args)
259
260        select.select = MockSelect()
261        try:
262            yield select.select
263        finally:
264            select.select = old_select
265
266    def test_InterruptServerSelectCall(self):
267        with self.mocked_select_module() as mock_select:
268            pid = self.run_server(SocketServer.TCPServer,
269                                  SocketServer.StreamRequestHandler,
270                                  self.stream_examine)
271            # Make sure select was called again:
272            self.assertGreater(mock_select.called, 1)
273
274    @requires_unix_sockets
275    def test_UnixDatagramServer(self):
276        self.run_server(SocketServer.UnixDatagramServer,
277                        SocketServer.DatagramRequestHandler,
278                        self.dgram_examine)
279
280    @requires_unix_sockets
281    def test_ThreadingUnixDatagramServer(self):
282        self.run_server(SocketServer.ThreadingUnixDatagramServer,
283                        SocketServer.DatagramRequestHandler,
284                        self.dgram_examine)
285
286    @requires_unix_sockets
287    @requires_forking
288    def test_ForkingUnixDatagramServer(self):
289        self.run_server(ForkingUnixDatagramServer,
290                        SocketServer.DatagramRequestHandler,
291                        self.dgram_examine)
292
293    @reap_threads
294    def test_shutdown(self):
295        # Issue #2302: shutdown() should always succeed in making an
296        # other thread leave serve_forever().
297        class MyServer(SocketServer.TCPServer):
298            pass
299
300        class MyHandler(SocketServer.StreamRequestHandler):
301            pass
302
303        threads = []
304        for i in range(20):
305            s = MyServer((HOST, 0), MyHandler)
306            t = threading.Thread(
307                name='MyServer serving',
308                target=s.serve_forever,
309                kwargs={'poll_interval':0.01})
310            t.daemon = True  # In case this function raises.
311            threads.append((t, s))
312        for t, s in threads:
313            t.start()
314            s.shutdown()
315        for t, s in threads:
316            t.join()
317
318    def test_tcpserver_bind_leak(self):
319        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
320        # failed.
321        # Create many servers for which bind() will fail, to see if this result
322        # in FD exhaustion.
323        for i in range(1024):
324            with self.assertRaises(OverflowError):
325                SocketServer.TCPServer((HOST, -1),
326                                       SocketServer.StreamRequestHandler)
327
328
class MiscTestCase(unittest.TestCase):
    """Standalone SocketServer regression tests that need no server thread."""

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(SocketServer.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0
            def shutdown_request(self, request):
                self.shutdown_called += 1
                SocketServer.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), SocketServer.StreamRequestHandler)
        try:
            # Queue one connection, closing the client before the server
            # processes it.
            client = socket.socket(server.address_family, socket.SOCK_STREAM)
            with contextlib.closing(client) as s:
                s.connect(server.server_address)
            server.handle_request()
            self.assertEqual(server.shutdown_called, 1)
        finally:
            # Close the listening socket even if the assertion fails.
            server.server_close()
351
352
def test_main():
    """Run every test case class defined in this module.

    Raises unittest.SkipTest if the import lock is currently held.
    """
    if imp.lock_held():
        # If the import lock is held, the threads will hang
        raise unittest.SkipTest("can't run when import lock is held")

    # run_unittest() only runs the classes passed to it, so every TestCase
    # in this module must be listed here.
    test.test_support.run_unittest(SocketServerTest, MiscTestCase)

if __name__ == "__main__":
    test_main()
362