test_socketserver.py revision 15ebc88d87d2ff8f520581a9f6a6816d78a7e504
1"""
2Test suite for SocketServer.py.
3"""
4
5import os
6import socket
7import errno
8import imp
9import select
10import time
11import threading
12from functools import wraps
13import unittest
14import SocketServer
15
16import test.test_support
17from test.test_support import reap_children, verbose, TestSkipped
18from test.test_support import TESTFN as TEST_FILE
19
# Only run when the regression-test "network" resource is enabled.
test.test_support.requires("network")

NREQ = 3                      # requests each server handles before its thread ends
DELAY = 0.5                   # seconds; used for handler pauses and test timing
TEST_STR = b"hello world\n"   # newline-terminated payload the servers echo back
HOST = "localhost"

# Feature probes that decide which server flavours get exercised.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = os.name != "os2" and hasattr(os, "fork")
29
30
class MyMixinHandler:
    """Request-handler mixin that echoes one line back to the client."""

    def handle(self):
        # Pause DELAY seconds before reading the line and again before
        # writing it back unchanged.
        time.sleep(DELAY)
        echoed = self.rfile.readline()
        time.sleep(DELAY)
        self.wfile.write(echoed)
37
38
def receive(sock, n, timeout=20):
    """Return up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    Raises RuntimeError if the socket does not become readable in time.
    """
    readable = select.select([sock], [], [], timeout)[0]
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
45
46
class MyStreamHandler(MyMixinHandler, SocketServer.StreamRequestHandler):
    """Stream request handler that echoes a single line (see MyMixinHandler)."""
49
class MyDatagramHandler(MyMixinHandler, SocketServer.DatagramRequestHandler):
    """Datagram request handler that echoes a single line (see MyMixinHandler)."""
53
if HAVE_UNIX_SOCKETS:
    # Forking variants of the AF_UNIX servers, assembled here from
    # SocketServer.ForkingMixIn.
    class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
                                  SocketServer.UnixStreamServer):
        """UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
                                    SocketServer.UnixDatagramServer):
        """UnixDatagramServer combined with ForkingMixIn."""
62
63
class MyMixinServer:
    """Server mixin: serve a fixed number of requests and fail loudly."""

    def serve_a_few(self):
        """Handle exactly NREQ requests, one after another."""
        remaining = NREQ
        while remaining > 0:
            self.handle_request()
            remaining -= 1

    def handle_error(self, request, client_address):
        """Close the request and the server, then re-raise.

        The bare ``raise`` re-raises the exception currently being handled,
        so this must be called from inside an ``except`` block.
        """
        self.close_request(request)
        self.server_close()
        raise
73
# NOTE: receive() is defined once, near the top of this module beside the
# request handlers; an identical second definition here only shadowed it.
80
def testdgram(proto, addr):
    """Send TEST_STR as a datagram to *addr* and check it is echoed back.

    Not used by SocketServerTest in this file; mirrors its dgram_examine().
    Raises AssertionError if the reply differs from TEST_STR, or
    RuntimeError (from receive()) if no reply arrives in time.
    """
    s = socket.socket(proto, socket.SOCK_DGRAM)
    try:
        s.sendto(TEST_STR, addr)
        buf = data = receive(s, 100)
        # Keep reading until the echoed line is complete or nothing arrives.
        while data and b'\n' not in buf:
            data = receive(s, 100)
            buf += data
        if buf != TEST_STR:
            raise AssertionError("%r != %r" % (buf, TEST_STR))
    finally:
        s.close()
90
def teststream(proto, addr):
    """Send TEST_STR over a stream connection to *addr* and check the echo.

    Not used by SocketServerTest in this file; mirrors its stream_examine().
    Raises AssertionError if the reply differs from TEST_STR, or
    RuntimeError (from receive()) if no reply arrives in time.
    """
    s = socket.socket(proto, socket.SOCK_STREAM)
    try:
        s.connect(addr)
        s.sendall(TEST_STR)
        buf = data = receive(s, 100)
        # Keep reading until the echoed line is complete or the peer closes.
        while data and b'\n' not in buf:
            data = receive(s, 100)
            buf += data
        if buf != TEST_STR:
            raise AssertionError("%r != %r" % (buf, TEST_STR))
    finally:
        s.close()
101
class ServerThread(threading.Thread):
    """Thread that builds one server and serves NREQ requests with it.

    ``ready`` is set once the server has been constructed and its address
    checked. At that point the name-mangled ``__addr`` attribute holds the
    address reported by the server, which may differ from the one requested.
    The server is always closed when the thread finishes.
    """

    def __init__(self, addr, svrcls, hdlrcls):
        threading.Thread.__init__(self)
        self.__addr = addr
        self.__svrcls = svrcls
        self.__hdlrcls = hdlrcls
        self.ready = threading.Event()

    def run(self):
        # Mix the request-counting/error-raising behaviour into the class
        # under test.
        class svrcls(MyMixinServer, self.__svrcls):
            pass
        if verbose: print("thread: creating server")
        svr = svrcls(self.__addr, self.__hdlrcls)
        try:
            # pull the address out of the server in case it changed
            # this can happen if another process is using the port
            addr = svr.server_address
            if addr:
                self.__addr = addr
                if self.__addr != svr.socket.getsockname():
                    raise RuntimeError('server_address was %s, expected %s' %
                                       (self.__addr, svr.socket.getsockname()))
            self.ready.set()
            if verbose: print("thread: serving three times")
            svr.serve_a_few()
            if verbose: print("thread: done")
        finally:
            # Release the server's socket whether serving succeeded or not;
            # previously it was left open after a successful run.
            svr.server_close()
127
128
class ForgivingTCPServer(SocketServer.TCPServer):
    """TCPServer that falls back to other ports if its port is in use.

    Only EADDRINUSE triggers a retry; any other bind error propagates
    immediately. If every candidate port is busy, the last EADDRINUSE
    error is re-raised rather than leaving the socket unbound.
    """

    # Ports tried, in order, after the requested one.  (Shamelessly stolen
    # from test.test_support; the ports were changed to protect the innocent.)
    fallback_ports = (3434, 8798, 23833)

    def server_bind(self):
        import sys
        host, default_port = self.server_address
        last_error = None
        for port in (default_port,) + tuple(self.fallback_ports):
            try:
                self.server_address = host, port
                SocketServer.TCPServer.server_bind(self)
                return
            except socket.error as e:
                # Exceptions are not sequences on Python 3, so read the
                # errno from .args instead of unpacking the exception.
                err = e.args[0] if e.args else None
                if err != errno.EADDRINUSE:
                    raise
                last_error = e
                print('  WARNING: failed to listen on port %d, trying another' % port, file=sys.__stderr__)
        # Every candidate port was busy: fail loudly.
        raise last_error
146
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        # Counter used by pickport(), and AF_UNIX socket files to clean up.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        # Pause, reap any leftover child processes, then delete socket files.
        time.sleep(DELAY)
        reap_children()

        for fn in self.test_files:
            try:
                os.remove(fn)
            except os.error:
                pass
        self.test_files[:] = []

    def pickport(self):
        """Return a pid-derived port number, distinct for each call in a test."""
        self.port_seed += 1
        return 10000 + (os.getpid() % 1000)*10 + self.port_seed

    def pickaddr(self, proto):
        """Return a fresh server address for address family *proto*.

        AF_INET gets a (HOST, port) tuple. Any other family gets a socket
        file name derived from TEST_FILE, which tearDown() removes.
        """
        if proto == socket.AF_INET:
            return (HOST, self.pickport())
        else:
            fn = TEST_FILE + str(self.pickport())
            if os.name == 'os2':
                # AF_UNIX socket names on OS/2 require a specific prefix
                # which can't include a drive letter and must also use
                # backslashes as directory separators
                if fn[1] == ':':
                    fn = fn[2:]
                if fn[0] in (os.sep, os.altsep):
                    fn = fn[1:]
                # The backslash must be escaped: '\s' is not a valid escape.
                fn = os.path.join('\\socket', fn)
                if os.sep == '/':
                    fn = fn.replace(os.sep, os.altsep)
                else:
                    fn = fn.replace(os.altsep, os.sep)
            self.test_files.append(fn)
            return fn

    def run_servers(self, proto, servers, hdlrcls, testfunc):
        """Run each server class in turn and drive NREQ client requests.

        For every class in *servers* a ServerThread is started on a fresh
        address. Once it is ready, testfunc(proto, addr) is called NREQ times
        against the address the server actually bound.
        """
        for svrcls in servers:
            addr = self.pickaddr(proto)
            if verbose:
                print("ADDR =", addr)
                print("CLASS =", svrcls)
            t = ServerThread(addr, svrcls, hdlrcls)
            if verbose: print("server created")
            t.start()
            if verbose: print("server running")
            t.ready.wait(10*DELAY)
            self.assertTrue(t.ready.isSet(),
                "Server not ready within a reasonable time")
            # ServerThread records the address the server really bound
            # (ForgivingTCPServer may have fallen back to another port) in
            # its name-mangled __addr attribute before setting `ready`;
            # the client must connect there, not to the requested address.
            addr = t._ServerThread__addr
            for i in range(NREQ):
                if verbose: print("test client", i)
                testfunc(proto, addr)
            if verbose: print("waiting for server")
            t.join()
            if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and check it is echoed back."""
        s = socket.socket(proto, socket.SOCK_STREAM)
        try:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            # Read until the echoed line is complete or the peer closes.
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)
        finally:
            s.close()

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and check it is echoed back."""
        s = socket.socket(proto, socket.SOCK_DGRAM)
        try:
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            # Read until the echoed line is complete or nothing arrives.
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)
        finally:
            s.close()

    def test_TCPServers(self):
        # Test SocketServer.TCPServer
        servers = [ForgivingTCPServer, SocketServer.ThreadingTCPServer]
        if HAVE_FORKING:
            servers.append(SocketServer.ForkingTCPServer)
        self.run_servers(socket.AF_INET, servers,
                         MyStreamHandler, self.stream_examine)

    def test_UDPServers(self):
        # Test SocketServer.UDPServer
        servers = [SocketServer.UDPServer,
                   SocketServer.ThreadingUDPServer]
        if HAVE_FORKING:
            servers.append(SocketServer.ForkingUDPServer)
        self.run_servers(socket.AF_INET, servers, MyDatagramHandler,
                         self.dgram_examine)

    def test_stream_servers(self):
        # Test SocketServer's stream servers
        if not HAVE_UNIX_SOCKETS:
            return
        servers = [SocketServer.UnixStreamServer,
                   SocketServer.ThreadingUnixStreamServer]
        if HAVE_FORKING:
            servers.append(ForkingUnixStreamServer)
        self.run_servers(socket.AF_UNIX, servers, MyStreamHandler,
                         self.stream_examine)

    # Alas, on Linux (at least) recvfrom() doesn't return a meaningful
    # client address so this cannot work:

    # def test_dgram_servers(self):
    #     # Test SocketServer.UnixDatagramServer
    #     if not HAVE_UNIX_SOCKETS:
    #         return
    #     servers = [SocketServer.UnixDatagramServer,
    #                SocketServer.ThreadingUnixDatagramServer]
    #     if HAVE_FORKING:
    #         servers.append(ForkingUnixDatagramServer)
    #     self.run_servers(socket.AF_UNIX, servers, MyDatagramHandler,
    #                      self.dgram_examine)
272
273
def test_main():
    """Run SocketServerTest, skipping it when the import lock is held."""
    # With the import lock held, the server threads would hang.
    if not imp.lock_held():
        test.test_support.run_unittest(SocketServerTest)
    else:
        raise TestSkipped("can't run when import lock is held")
280
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    test_main()
283