1"""
2Test suite for socketserver.
3"""
4
5import contextlib
6import io
7import os
8import select
9import signal
10import socket
11import tempfile
12import unittest
13import socketserver
14
15import test.support
16from test.support import reap_children, reap_threads, verbose
17try:
18    import threading
19except ImportError:
20    threading = None
21
# Most tests here use real sockets, so the module declares that it needs the
# "network" test resource before defining anything else.
test.support.requires("network")

TEST_STR = b"hello world\n"  # payload the echo handlers send back
HOST = test.support.HOST

# Optional platform features, each paired with a skip decorator for the
# tests that depend on it.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(
    HAVE_UNIX_SOCKETS, 'requires Unix sockets')
HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(
    HAVE_FORKING, 'requires forking')
32
def signal_alarm(n):
    """Schedule a SIGALRM in *n* seconds (``0`` cancels a pending alarm).

    On platforms whose signal module has no ``alarm`` (i.e. Windows) this
    silently does nothing.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
37
# Capture the genuine select() at import time, so that tests which mock
# select.select cannot interfere with receive().
_real_select = select.select

def receive(sock, n, timeout=20):
    """Read up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    Raises RuntimeError if *sock* does not become readable in time.
    """
    readable, _writable, _exceptional = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
47
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    # Forking AF_UNIX servers used by the test_ForkingUnix*Server tests.
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        """UnixStreamServer that handles each request in a forked child."""

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        """UnixDatagramServer that handles each request in a forked child."""
56
57
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)

    Forks a child that exits immediately with status 72, then runs the body
    of the ``with`` statement.  Afterwards the child is reaped here with
    os.waitpid(), which fails if something else already waited on it.  On
    success, the wait status must correspond to exit code 72.

    The child is reaped even if the body raises, so a failing test does not
    leave a zombie process behind.  In that case the original exception
    propagates and the status assertions are skipped.

    :param testcase: object providing ``assertEqual`` (a TestCase).
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        pid2, status = os.waitpid(pid, 0)
    # Only reached when the body succeeded.
    testcase.assertEqual(pid2, pid)
    testcase.assertEqual(72 << 8, status)
69
70
@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        # AF_UNIX socket paths handed out by pickaddr(); removed in tearDown.
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        # Best-effort cleanup: a path that was never bound (or is already
        # gone) is not an error.
        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return a fresh address for a socket of address family *proto*.

        For AF_INET this is ``(HOST, 0)`` so the OS picks a free port.  For
        any other family a new temporary path is returned and recorded so
        tearDown() can remove it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.  mktemp() only yields an unused
        # name; the socket file is created when the socket is bound.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a one-line echo server of class *svrcls* bound to *addr*.

        The handler class derives from *hdlrbase*.  Errors raised while
        handling a request are re-raised by handle_error() so they surface
        instead of being swallowed.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                # Relies on handle_error() being called while the handler's
                # exception is still active, so a bare raise re-raises it.
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        server = MyServer(addr, MyHandler)
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve on a background thread and run *testfunc* against it 3 times.

        *testfunc* is called as ``testfunc(address_family, address)``.  The
        server is always shut down and closed, even if *testfunc* fails, so
        a failing test does not leak the server socket or leave the serving
        thread running.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        # shutdown() waits for serve_forever() to stop, so the cleanup below
        # is only armed once the serving thread has been started.
        try:
            if verbose: print("server running")
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
            if verbose: print("waiting for server")
        finally:
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream connection to *addr* and check that
        the server echoes it back (reading until a newline or EOF)."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram to *addr* and check that the server
        echoes it back (reading until a newline or an empty datagram)."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # An unbound AF_UNIX datagram socket has no address the
                # server could reply to.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        # Like the other forking tests, also check that the server does not
        # wait on a child process it did not create (Issue 1540386).
        with simple_subprocess(self):
            self.run_server(ForkingUnixDatagramServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
286
287
class ErrorHandlerTest(unittest.TestCase):
    """Test that the servers pass normal exceptions from the handler to
    handle_error(), and that exiting exceptions like SystemExit and
    KeyboardInterrupt are not passed."""

    def tearDown(self):
        # Drop the log written by the handler and handle_error() so the next
        # test starts from an empty file.
        test.support.unlink(test.support.TESTFN)

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        # With a synchronous server the exiting exception escapes
        # handle_request() and therefore the constructor.
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def check_result(self, handled):
        """Assert the log shows that the handler ran, followed by the
        handle_error() line *handled* times (once for True, never for
        False)."""
        with open(test.support.TESTFN) as log:
            expected_lines = ['Handler called\n']
            expected_lines.extend(['Error handled\n'] * handled)
            self.assertEqual(log.read(), ''.join(expected_lines))
329
330
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that serves exactly one request from its constructor.

    The request is handled by BadHandler, which raises ``self.exception``.
    handle_error() appends a line to the TESTFN log.  The listening socket is
    closed before wait_done() is called, even if connecting or handling the
    request raised.

    :param exception: exception class that BadHandler should raise.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        try:
            # Connect and close at once: the server still has a queued
            # connection for handle_request() to accept.  Done inside the
            # try so a failed connect does not leak the listening socket.
            with socket.create_connection(self.server_address):
                pass
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        # Record that the server's error hook ran.
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses that handle the request asynchronously; the
        synchronous base server has nothing to wait for."""
        pass
349
350
class BadHandler(socketserver.BaseRequestHandler):
    """Request handler that logs that it ran, then raises an instance of the
    server's ``exception`` class."""

    def handle(self):
        with open(test.support.TESTFN, 'a') as log_file:
            log_file.write('Handler called\n')
        exc_class = self.server.exception
        raise exc_class('Test error')
356
357
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """BaseErrorTestServer variant that handles its request in a thread.

    wait_done() blocks until shutdown_request() has been called for the
    request, which signals the ``done`` event.
    """

    def __init__(self, *pos, **kw):
        # Created before the base constructor runs, because that constructor
        # serves the request and so triggers shutdown_request().
        self.done = threading.Event()
        super().__init__(*pos, **kw)

    def shutdown_request(self, *pos, **kw):
        super().shutdown_request(*pos, **kw)
        # Signal only after the base class has finished with the request.
        self.done.set()

    def wait_done(self):
        self.done.wait()
370
371
if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn,
                                 BaseErrorTestServer):
        """BaseErrorTestServer variant that handles its request in a forked
        child; wait_done() reaps that child."""

        def wait_done(self):
            # Exactly one request was served, so exactly one child should
            # be tracked; the unpacking raises ValueError otherwise.
            (child_pid,) = self.active_children
            os.waitpid(child_pid, 0)
            self.active_children.clear()
378
379
class SocketWriterTest(unittest.TestCase):
    """Tests for the ``wfile`` object of StreamRequestHandler."""

    def test_basics(self):
        # The handler stashes what it saw on the server, so it can be
        # inspected once handle_request() returns.
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.wfile = self.wfile
                srv.wfile_fileno = self.wfile.fileno()
                srv.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        # Connect and close straight away; the handler reads nothing, it
        # only needs an accepted connection.
        with socket.socket(server.address_family, socket.SOCK_STREAM,
                           socket.IPPROTO_TCP) as client:
            client.connect(server.server_address)
        server.handle_request()
        # wfile must be a buffered binary writer on the request socket's fd.
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                srv.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                srv.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def on_sigusr1(signum, frame):
            interrupted.set()

        previous_handler = signal.signal(signal.SIGUSR1, on_sigusr1)
        self.addCleanup(signal.signal, signal.SIGUSR1, previous_handler)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            # Results are handed back to the test through these closures.
            nonlocal response1, received2
            client = socket.socket(server.address_family, socket.SOCK_STREAM,
                                   socket.IPPROTO_TCP)
            with client, client.makefile('rb') as reader:
                client.connect(server.server_address)
                response1 = reader.readline()
                client.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                signalled = False
                while not signalled:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    signalled = interrupted.wait(timeout=1.0)
                received2 = len(reader.read())

        client_thread = threading.Thread(target=run_client)
        client_thread.start()
        server.handle_request()
        client_thread.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
456
457
458class MiscTestCase(unittest.TestCase):
459
460    def test_all(self):
461        # objects defined in the module should be in __all__
462        expected = []
463        for name in dir(socketserver):
464            if not name.startswith('_'):
465                mod_object = getattr(socketserver, name)
466                if getattr(mod_object, '__module__', None) == 'socketserver':
467                    expected.append(name)
468        self.assertCountEqual(socketserver.__all__, expected)
469
470    def test_shutdown_request_called_if_verify_request_false(self):
471        # Issue #26309: BaseServer should call shutdown_request even if
472        # verify_request is False
473
474        class MyServer(socketserver.TCPServer):
475            def verify_request(self, request, client_address):
476                return False
477
478            shutdown_called = 0
479            def shutdown_request(self, request):
480                self.shutdown_called += 1
481                socketserver.TCPServer.shutdown_request(self, request)
482
483        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
484        s = socket.socket(server.address_family, socket.SOCK_STREAM)
485        s.connect(server.server_address)
486        s.close()
487        server.handle_request()
488        self.assertEqual(server.shutdown_called, 1)
489        server.server_close()
490
491
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()
494