1#!/usr/bin/env python
2
3import unittest
4from test import test_support
5
6import errno
7import socket
8import select
9import _testcapi
10import time
11import traceback
12import Queue
13import sys
14import os
15import array
16import contextlib
17from weakref import proxy
18import signal
19import math
20
def try_address(host, port=0, family=socket.AF_INET):
    """Try to bind a TCP socket on the given host:port.

    Returns True if a socket of *family* could be created and bound to
    (host, port), False if creating or binding it raised socket.error or
    socket.gaierror.  The probe socket is always closed, including when
    bind() fails.
    """
    try:
        sock = socket.socket(family, socket.SOCK_STREAM)
    except (socket.error, socket.gaierror):
        return False
    try:
        sock.bind((host, port))
    except (socket.error, socket.gaierror):
        return False
    finally:
        # Close the probe socket on both the success and failure paths;
        # previously a failed bind() leaked the descriptor.
        sock.close()
    return True
32
# Address the test servers bind to, and the payload most tests exchange.
# (These used to be defined twice in this module; one definition is enough.)
HOST = test_support.HOST
MSG = 'Michael Gilfix was here\n'
# True only if the socket module was built with IPv6 support AND a socket
# can actually be bound to the IPv6 loopback address.
SUPPORTS_IPV6 = socket.has_ipv6 and try_address('::1', family=socket.AF_INET6)

# Threading is optional: threaded test classes are skipped when these are None.
try:
    import thread
    import threading
except ImportError:
    thread = None
    threading = None
46
class SocketTCPTest(unittest.TestCase):
    """Base fixture providing a listening TCP server socket.

    setUp() stores the socket in self.serv and the port chosen by
    test_support.bind_port() in self.port, then starts listening with a
    backlog of 1.  tearDown() closes the socket and drops the reference.
    """

    def setUp(self):
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.serv = server
        self.port = test_support.bind_port(server)
        server.listen(1)

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
57
class SocketUDPTest(unittest.TestCase):
    """Base fixture providing a bound UDP server socket.

    setUp() stores the socket in self.serv and the port chosen by
    test_support.bind_port() in self.port.  tearDown() closes the socket
    and drops the reference.
    """

    def setUp(self):
        server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.serv = server
        self.port = test_support.bind_port(server)

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None
67
class ThreadableTest:
    """Threadable Test class

    The ThreadableTest class makes it easy to create a threaded
    client/server pair from an existing unit test. To create a
    new threaded class from an existing unit test, use multiple
    inheritance:

        class NewClass (OldClass, ThreadableTest):
            pass

    This class defines two new fixture functions with obvious
    purposes for overriding:

        clientSetUp ()
        clientTearDown ()

    Any new test functions within the class must then define
    tests in pairs, where the test name is preceded with a
    '_' to indicate the client portion of the test. Ex:

        def testFoo(self):
            # Server portion

        def _testFoo(self):
            # Client portion

    Any exceptions raised by the clients during their setup or tests
    are caught and transferred to the main thread to alert
    the testing framework.

    Note, the server setup function cannot call any blocking
    functions that rely on the client thread during setup,
    unless serverExplicitReady() is called just before
    the blocking call (such as in setting up a client/server
    connection and performing the accept() in setUp()).
    """

    def __init__(self):
        # Keep the real setUp/tearDown (bound at this point) and install
        # the threaded wrappers in their place.
        self.__setUp = self.setUp
        self.__tearDown = self.tearDown
        self.setUp = self._setUp
        self.tearDown = self._tearDown

    def serverExplicitReady(self):
        """This method allows the server to explicitly indicate that
        it wants the client thread to proceed. This is useful if the
        server is about to execute a blocking routine that is
        dependent upon the client thread during its setup routine."""
        self.server_ready.set()

    def _setUp(self):
        """Start the client thread, run the real setUp, then wait until
        the client thread has finished clientSetUp()."""
        self.server_ready = threading.Event()
        self.client_ready = threading.Event()
        self.done = threading.Event()
        # Holds at most one exception reported by the client thread.
        self.queue = Queue.Queue(1)

        # The client half of test "testFoo" is the method "_testFoo".
        methodname = self.id()
        i = methodname.rfind('.')
        methodname = methodname[i+1:]
        test_method = getattr(self, '_' + methodname)
        self.client_thread = thread.start_new_thread(
            self.clientRun, (test_method,))

        self.__setUp()
        if not self.server_ready.is_set():
            self.server_ready.set()
        self.client_ready.wait()

    def _tearDown(self):
        """Run the real tearDown, wait for the client thread to finish,
        and fail the test with any exception the client reported."""
        self.__tearDown()
        self.done.wait()

        if not self.queue.empty():
            msg = self.queue.get()
            self.fail(msg)

    def clientRun(self, test_func):
        """Body of the client thread.

        Waits for the server side to signal readiness, runs clientSetUp(),
        releases the main thread, runs the client half of the test and
        finally clientTearDown().  An exception from setup or from the
        client test is put on self.queue for _tearDown() to report.
        self.client_ready and self.done are always set, so the main thread
        cannot block forever on them, even if setup or teardown fails.
        """
        self.server_ready.wait()
        try:
            try:
                self.clientSetUp()
            except Exception as e:
                # Hand the setup failure to the main thread and skip the
                # client half of the test.
                self.queue.put(e)
                return
            finally:
                # Release the main thread waiting in _setUp(), whether or
                # not clientSetUp() succeeded.
                self.client_ready.set()
            try:
                if not callable(test_func):
                    raise TypeError("test_func must be a callable function.")
                test_func()
            except Exception as e:
                self.queue.put(e)
        finally:
            try:
                self.clientTearDown()
            finally:
                # The base clientTearDown() sets self.done itself; setting it
                # here as well releases _tearDown() even when a subclass
                # teardown raises before reaching the base implementation.
                self.done.set()

    def clientSetUp(self):
        raise NotImplementedError("clientSetUp must be implemented.")

    def clientTearDown(self):
        # Signal the main thread, then terminate the client thread.
        self.done.set()
        thread.exit()
165
class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
    """TCP fixture whose client half runs in a separate thread.

    The server socket comes from SocketTCPTest; the client thread gets a
    fresh, unconnected TCP socket in self.cli.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        # Installs the threaded setUp/tearDown wrappers.
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        family, kind = socket.AF_INET, socket.SOCK_STREAM
        self.cli = socket.socket(family, kind)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
179
class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
    """UDP fixture whose client half runs in a separate thread.

    The server socket comes from SocketUDPTest; the client thread gets a
    fresh UDP socket in self.cli.
    """

    def __init__(self, methodName='runTest'):
        SocketUDPTest.__init__(self, methodName=methodName)
        # Installs the threaded setUp/tearDown wrappers.
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        family, kind = socket.AF_INET, socket.SOCK_DGRAM
        self.cli = socket.socket(family, kind)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
193
class SocketConnectedTest(ThreadedTCPSocketTest):
    """Threaded TCP fixture with an established connection.

    Server side: self.cli_conn is the socket returned by accept().
    Client side: self.serv_conn is the client socket, connected to
    (HOST, self.port).
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        ThreadedTCPSocketTest.setUp(self)
        # accept() blocks until the client thread connects, so let the
        # client thread proceed before calling it.
        self.serverExplicitReady()
        self.cli_conn = self.serv.accept()[0]

    def tearDown(self):
        conn = self.cli_conn
        conn.close()
        self.cli_conn = None
        ThreadedTCPSocketTest.tearDown(self)

    def clientSetUp(self):
        ThreadedTCPSocketTest.clientSetUp(self)
        client = self.cli
        client.connect((HOST, self.port))
        self.serv_conn = client

    def clientTearDown(self):
        conn = self.serv_conn
        conn.close()
        self.serv_conn = None
        ThreadedTCPSocketTest.clientTearDown(self)
221
class SocketPairTest(unittest.TestCase, ThreadableTest):
    """Threaded fixture built on socket.socketpair().

    The main (server) thread uses self.serv and the client thread uses
    self.cli; each side closes its own end in its teardown.
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        # Installs the threaded setUp/tearDown wrappers.
        ThreadableTest.__init__(self)

    def setUp(self):
        pair = socket.socketpair()
        self.serv, self.cli = pair

    def tearDown(self):
        server = self.serv
        server.close()
        self.serv = None

    def clientSetUp(self):
        # Both ends were already created in setUp(); nothing to do here.
        pass

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
242
243
244#######################################################################
245## Begin Tests
246
247class GeneralModuleTests(unittest.TestCase):
248
249    def test_weakref(self):
250        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
251        p = proxy(s)
252        self.assertEqual(p.fileno(), s.fileno())
253        s.close()
254        s = None
255        try:
256            p.fileno()
257        except ReferenceError:
258            pass
259        else:
260            self.fail('Socket proxy still exists')
261
262    def testSocketError(self):
263        # Testing socket module exceptions
264        def raise_error(*args, **kwargs):
265            raise socket.error
266        def raise_herror(*args, **kwargs):
267            raise socket.herror
268        def raise_gaierror(*args, **kwargs):
269            raise socket.gaierror
270        self.assertRaises(socket.error, raise_error,
271                              "Error raising socket exception.")
272        self.assertRaises(socket.error, raise_herror,
273                              "Error raising socket exception.")
274        self.assertRaises(socket.error, raise_gaierror,
275                              "Error raising socket exception.")
276
277    def testSendtoErrors(self):
278        # Testing that sendto doens't masks failures. See #10169.
279        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
280        self.addCleanup(s.close)
281        s.bind(('', 0))
282        sockname = s.getsockname()
283        # 2 args
284        with self.assertRaises(UnicodeEncodeError):
285            s.sendto(u'\u2620', sockname)
286        with self.assertRaises(TypeError) as cm:
287            s.sendto(5j, sockname)
288        self.assertIn('not complex', str(cm.exception))
289        with self.assertRaises(TypeError) as cm:
290            s.sendto('foo', None)
291        self.assertIn('not NoneType', str(cm.exception))
292        # 3 args
293        with self.assertRaises(UnicodeEncodeError):
294            s.sendto(u'\u2620', 0, sockname)
295        with self.assertRaises(TypeError) as cm:
296            s.sendto(5j, 0, sockname)
297        self.assertIn('not complex', str(cm.exception))
298        with self.assertRaises(TypeError) as cm:
299            s.sendto('foo', 0, None)
300        self.assertIn('not NoneType', str(cm.exception))
301        with self.assertRaises(TypeError) as cm:
302            s.sendto('foo', 'bar', sockname)
303        self.assertIn('an integer is required', str(cm.exception))
304        with self.assertRaises(TypeError) as cm:
305            s.sendto('foo', None, None)
306        self.assertIn('an integer is required', str(cm.exception))
307        # wrong number of args
308        with self.assertRaises(TypeError) as cm:
309            s.sendto('foo')
310        self.assertIn('(1 given)', str(cm.exception))
311        with self.assertRaises(TypeError) as cm:
312            s.sendto('foo', 0, sockname, 4)
313        self.assertIn('(4 given)', str(cm.exception))
314
315
316    def testCrucialConstants(self):
317        # Testing for mission critical constants
318        socket.AF_INET
319        socket.SOCK_STREAM
320        socket.SOCK_DGRAM
321        socket.SOCK_RAW
322        socket.SOCK_RDM
323        socket.SOCK_SEQPACKET
324        socket.SOL_SOCKET
325        socket.SO_REUSEADDR
326
327    def testHostnameRes(self):
328        # Testing hostname resolution mechanisms
329        hostname = socket.gethostname()
330        try:
331            ip = socket.gethostbyname(hostname)
332        except socket.error:
333            # Probably name lookup wasn't set up right; skip this test
334            return
335        self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.")
336        try:
337            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
338        except socket.error:
339            # Probably a similar problem as above; skip this test
340            return
341        all_host_names = [hostname, hname] + aliases
342        fqhn = socket.getfqdn(ip)
343        if not fqhn in all_host_names:
344            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
345
346    def testRefCountGetNameInfo(self):
347        # Testing reference count for getnameinfo
348        if hasattr(sys, "getrefcount"):
349            try:
350                # On some versions, this loses a reference
351                orig = sys.getrefcount(__name__)
352                socket.getnameinfo(__name__,0)
353            except TypeError:
354                self.assertEqual(sys.getrefcount(__name__), orig,
355                                 "socket.getnameinfo loses a reference")
356
357    def testInterpreterCrash(self):
358        # Making sure getnameinfo doesn't crash the interpreter
359        try:
360            # On some versions, this crashes the interpreter.
361            socket.getnameinfo(('x', 0, 0, 0), 0)
362        except socket.error:
363            pass
364
365    def testNtoH(self):
366        # This just checks that htons etc. are their own inverse,
367        # when looking at the lower 16 or 32 bits.
368        sizes = {socket.htonl: 32, socket.ntohl: 32,
369                 socket.htons: 16, socket.ntohs: 16}
370        for func, size in sizes.items():
371            mask = (1L<<size) - 1
372            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
373                self.assertEqual(i & mask, func(func(i&mask)) & mask)
374
375            swapped = func(mask)
376            self.assertEqual(swapped & mask, mask)
377            self.assertRaises(OverflowError, func, 1L<<34)
378
379    def testNtoHErrors(self):
380        good_values = [ 1, 2, 3, 1L, 2L, 3L ]
381        bad_values = [ -1, -2, -3, -1L, -2L, -3L ]
382        for k in good_values:
383            socket.ntohl(k)
384            socket.ntohs(k)
385            socket.htonl(k)
386            socket.htons(k)
387        for k in bad_values:
388            self.assertRaises(OverflowError, socket.ntohl, k)
389            self.assertRaises(OverflowError, socket.ntohs, k)
390            self.assertRaises(OverflowError, socket.htonl, k)
391            self.assertRaises(OverflowError, socket.htons, k)
392
393    def testGetServBy(self):
394        eq = self.assertEqual
395        # Find one service that exists, then check all the related interfaces.
396        # I've ordered this by protocols that have both a tcp and udp
397        # protocol, at least for modern Linuxes.
398        if (sys.platform.startswith('linux') or
399            sys.platform.startswith('freebsd') or
400            sys.platform.startswith('netbsd') or
401            sys.platform == 'darwin'):
402            # avoid the 'echo' service on this platform, as there is an
403            # assumption breaking non-standard port/protocol entry
404            services = ('daytime', 'qotd', 'domain')
405        else:
406            services = ('echo', 'daytime', 'domain')
407        for service in services:
408            try:
409                port = socket.getservbyname(service, 'tcp')
410                break
411            except socket.error:
412                pass
413        else:
414            raise socket.error
415        # Try same call with optional protocol omitted
416        port2 = socket.getservbyname(service)
417        eq(port, port2)
418        # Try udp, but don't barf if it doesn't exist
419        try:
420            udpport = socket.getservbyname(service, 'udp')
421        except socket.error:
422            udpport = None
423        else:
424            eq(udpport, port)
425        # Now make sure the lookup by port returns the same service name
426        eq(socket.getservbyport(port2), service)
427        eq(socket.getservbyport(port, 'tcp'), service)
428        if udpport is not None:
429            eq(socket.getservbyport(udpport, 'udp'), service)
430        # Make sure getservbyport does not accept out of range ports.
431        self.assertRaises(OverflowError, socket.getservbyport, -1)
432        self.assertRaises(OverflowError, socket.getservbyport, 65536)
433
434    def testDefaultTimeout(self):
435        # Testing default timeout
436        # The default timeout should initially be None
437        self.assertEqual(socket.getdefaulttimeout(), None)
438        s = socket.socket()
439        self.assertEqual(s.gettimeout(), None)
440        s.close()
441
442        # Set the default timeout to 10, and see if it propagates
443        socket.setdefaulttimeout(10)
444        self.assertEqual(socket.getdefaulttimeout(), 10)
445        s = socket.socket()
446        self.assertEqual(s.gettimeout(), 10)
447        s.close()
448
449        # Reset the default timeout to None, and see if it propagates
450        socket.setdefaulttimeout(None)
451        self.assertEqual(socket.getdefaulttimeout(), None)
452        s = socket.socket()
453        self.assertEqual(s.gettimeout(), None)
454        s.close()
455
456        # Check that setting it to an invalid value raises ValueError
457        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
458
459        # Check that setting it to an invalid type raises TypeError
460        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
461
462    def testIPv4_inet_aton_fourbytes(self):
463        if not hasattr(socket, 'inet_aton'):
464            return  # No inet_aton, nothing to check
465        # Test that issue1008086 and issue767150 are fixed.
466        # It must return 4 bytes.
467        self.assertEqual('\x00'*4, socket.inet_aton('0.0.0.0'))
468        self.assertEqual('\xff'*4, socket.inet_aton('255.255.255.255'))
469
470    def testIPv4toString(self):
471        if not hasattr(socket, 'inet_pton'):
472            return # No inet_pton() on this platform
473        from socket import inet_aton as f, inet_pton, AF_INET
474        g = lambda a: inet_pton(AF_INET, a)
475
476        self.assertEqual('\x00\x00\x00\x00', f('0.0.0.0'))
477        self.assertEqual('\xff\x00\xff\x00', f('255.0.255.0'))
478        self.assertEqual('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
479        self.assertEqual('\x01\x02\x03\x04', f('1.2.3.4'))
480        self.assertEqual('\xff\xff\xff\xff', f('255.255.255.255'))
481
482        self.assertEqual('\x00\x00\x00\x00', g('0.0.0.0'))
483        self.assertEqual('\xff\x00\xff\x00', g('255.0.255.0'))
484        self.assertEqual('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
485        self.assertEqual('\xff\xff\xff\xff', g('255.255.255.255'))
486
487    def testIPv6toString(self):
488        if not hasattr(socket, 'inet_pton'):
489            return # No inet_pton() on this platform
490        try:
491            from socket import inet_pton, AF_INET6, has_ipv6
492            if not has_ipv6:
493                return
494        except ImportError:
495            return
496        f = lambda a: inet_pton(AF_INET6, a)
497
498        self.assertEqual('\x00' * 16, f('::'))
499        self.assertEqual('\x00' * 16, f('0::0'))
500        self.assertEqual('\x00\x01' + '\x00' * 14, f('1::'))
501        self.assertEqual(
502            '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
503            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
504        )
505
506    def testStringToIPv4(self):
507        if not hasattr(socket, 'inet_ntop'):
508            return # No inet_ntop() on this platform
509        from socket import inet_ntoa as f, inet_ntop, AF_INET
510        g = lambda a: inet_ntop(AF_INET, a)
511
512        self.assertEqual('1.0.1.0', f('\x01\x00\x01\x00'))
513        self.assertEqual('170.85.170.85', f('\xaa\x55\xaa\x55'))
514        self.assertEqual('255.255.255.255', f('\xff\xff\xff\xff'))
515        self.assertEqual('1.2.3.4', f('\x01\x02\x03\x04'))
516
517        self.assertEqual('1.0.1.0', g('\x01\x00\x01\x00'))
518        self.assertEqual('170.85.170.85', g('\xaa\x55\xaa\x55'))
519        self.assertEqual('255.255.255.255', g('\xff\xff\xff\xff'))
520
521    def testStringToIPv6(self):
522        if not hasattr(socket, 'inet_ntop'):
523            return # No inet_ntop() on this platform
524        try:
525            from socket import inet_ntop, AF_INET6, has_ipv6
526            if not has_ipv6:
527                return
528        except ImportError:
529            return
530        f = lambda a: inet_ntop(AF_INET6, a)
531
532        self.assertEqual('::', f('\x00' * 16))
533        self.assertEqual('::1', f('\x00' * 15 + '\x01'))
534        self.assertEqual(
535            'aef:b01:506:1001:ffff:9997:55:170',
536            f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
537        )
538
539    # XXX The following don't test module-level functionality...
540
541    def _get_unused_port(self, bind_address='0.0.0.0'):
542        """Use a temporary socket to elicit an unused ephemeral port.
543
544        Args:
545            bind_address: Hostname or IP address to search for a port on.
546
547        Returns: A most likely to be unused port.
548        """
549        tempsock = socket.socket()
550        tempsock.bind((bind_address, 0))
551        host, port = tempsock.getsockname()
552        tempsock.close()
553        return port
554
555    def testSockName(self):
556        # Testing getsockname()
557        port = self._get_unused_port()
558        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
559        self.addCleanup(sock.close)
560        sock.bind(("0.0.0.0", port))
561        name = sock.getsockname()
562        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
563        # it reasonable to get the host's addr in addition to 0.0.0.0.
564        # At least for eCos.  This is required for the S/390 to pass.
565        try:
566            my_ip_addr = socket.gethostbyname(socket.gethostname())
567        except socket.error:
568            # Probably name lookup wasn't set up right; skip this test
569            return
570        self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
571        self.assertEqual(name[1], port)
572
573    def testGetSockOpt(self):
574        # Testing getsockopt()
575        # We know a socket should start without reuse==0
576        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
577        self.addCleanup(sock.close)
578        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
579        self.assertFalse(reuse != 0, "initial mode is reuse")
580
581    def testSetSockOpt(self):
582        # Testing setsockopt()
583        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
584        self.addCleanup(sock.close)
585        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
586        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
587        self.assertFalse(reuse == 0, "failed to set reuse mode")
588
589    def testSendAfterClose(self):
590        # testing send() after close() with timeout
591        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
592        sock.settimeout(1)
593        sock.close()
594        self.assertRaises(socket.error, sock.send, "spam")
595
596    def testNewAttributes(self):
597        # testing .family, .type and .protocol
598        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
599        self.assertEqual(sock.family, socket.AF_INET)
600        self.assertEqual(sock.type, socket.SOCK_STREAM)
601        self.assertEqual(sock.proto, 0)
602        sock.close()
603
604    def test_getsockaddrarg(self):
605        host = '0.0.0.0'
606        port = self._get_unused_port(bind_address=host)
607        big_port = port + 65536
608        neg_port = port - 65536
609        sock = socket.socket()
610        try:
611            self.assertRaises(OverflowError, sock.bind, (host, big_port))
612            self.assertRaises(OverflowError, sock.bind, (host, neg_port))
613            sock.bind((host, port))
614        finally:
615            sock.close()
616
617    @unittest.skipUnless(os.name == "nt", "Windows specific")
618    def test_sock_ioctl(self):
619        self.assertTrue(hasattr(socket.socket, 'ioctl'))
620        self.assertTrue(hasattr(socket, 'SIO_RCVALL'))
621        self.assertTrue(hasattr(socket, 'RCVALL_ON'))
622        self.assertTrue(hasattr(socket, 'RCVALL_OFF'))
623        self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS'))
624        s = socket.socket()
625        self.addCleanup(s.close)
626        self.assertRaises(ValueError, s.ioctl, -1, None)
627        s.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 100, 100))
628
629    def testGetaddrinfo(self):
630        try:
631            socket.getaddrinfo('localhost', 80)
632        except socket.gaierror as err:
633            if err.errno == socket.EAI_SERVICE:
634                # see http://bugs.python.org/issue1282647
635                self.skipTest("buggy libc version")
636            raise
637        # len of every sequence is supposed to be == 5
638        for info in socket.getaddrinfo(HOST, None):
639            self.assertEqual(len(info), 5)
640        # host can be a domain name, a string representation of an
641        # IPv4/v6 address or None
642        socket.getaddrinfo('localhost', 80)
643        socket.getaddrinfo('127.0.0.1', 80)
644        socket.getaddrinfo(None, 80)
645        if SUPPORTS_IPV6:
646            socket.getaddrinfo('::1', 80)
647        # port can be a string service name such as "http", a numeric
648        # port number (int or long), or None
649        socket.getaddrinfo(HOST, "http")
650        socket.getaddrinfo(HOST, 80)
651        socket.getaddrinfo(HOST, 80L)
652        socket.getaddrinfo(HOST, None)
653        # test family and socktype filters
654        infos = socket.getaddrinfo(HOST, None, socket.AF_INET)
655        for family, _, _, _, _ in infos:
656            self.assertEqual(family, socket.AF_INET)
657        infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
658        for _, socktype, _, _, _ in infos:
659            self.assertEqual(socktype, socket.SOCK_STREAM)
660        # test proto and flags arguments
661        socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP)
662        socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE)
663        # a server willing to support both IPv4 and IPv6 will
664        # usually do this
665        socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
666                           socket.AI_PASSIVE)
667
668
669    def check_sendall_interrupted(self, with_timeout):
670        # socketpair() is not stricly required, but it makes things easier.
671        if not hasattr(signal, 'alarm') or not hasattr(socket, 'socketpair'):
672            self.skipTest("signal.alarm and socket.socketpair required for this test")
673        # Our signal handlers clobber the C errno by calling a math function
674        # with an invalid domain value.
675        def ok_handler(*args):
676            self.assertRaises(ValueError, math.acosh, 0)
677        def raising_handler(*args):
678            self.assertRaises(ValueError, math.acosh, 0)
679            1 // 0
680        c, s = socket.socketpair()
681        old_alarm = signal.signal(signal.SIGALRM, raising_handler)
682        try:
683            if with_timeout:
684                # Just above the one second minimum for signal.alarm
685                c.settimeout(1.5)
686            with self.assertRaises(ZeroDivisionError):
687                signal.alarm(1)
688                c.sendall(b"x" * (1024**2))
689            if with_timeout:
690                signal.signal(signal.SIGALRM, ok_handler)
691                signal.alarm(1)
692                self.assertRaises(socket.timeout, c.sendall, b"x" * (1024**2))
693        finally:
694            signal.signal(signal.SIGALRM, old_alarm)
695            c.close()
696            s.close()
697
698    def test_sendall_interrupted(self):
699        self.check_sendall_interrupted(False)
700
701    def test_sendall_interrupted_with_timeout(self):
702        self.check_sendall_interrupted(True)
703
704    def test_listen_backlog(self):
705        for backlog in 0, -1:
706            srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
707            srv.bind((HOST, 0))
708            srv.listen(backlog)
709            srv.close()
710
711        # Issue 15989
712        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
713        srv.bind((HOST, 0))
714        self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1)
715        srv.close()
716
717    @unittest.skipUnless(SUPPORTS_IPV6, 'IPv6 required for this test.')
718    def test_flowinfo(self):
719        self.assertRaises(OverflowError, socket.getnameinfo,
720                          ('::1',0, 0xffffffff), 0)
721        s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
722        try:
723            self.assertRaises(OverflowError, s.bind, ('::1', 0, -10))
724        finally:
725            s.close()
726
727
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicTCPTest(SocketConnectedTest):
    """Basic send/recv round trips over a connected TCP socket pair.

    Each testX method runs in the main thread against self.cli_conn and is
    paired with a _testX method that runs in the client thread against
    self.serv_conn.
    """

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # Testing large receive over TCP
        self.assertEqual(self.cli_conn.recv(1024), MSG)

    def _testRecv(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecv(self):
        # Testing receive in chunks over TCP
        head = self.cli_conn.recv(len(MSG) - 3)
        tail = self.cli_conn.recv(1024)
        self.assertEqual(head + tail, MSG)

    def _testOverFlowRecv(self):
        self.serv_conn.send(MSG)

    def testRecvFrom(self):
        # Testing large recvfrom() over TCP
        data = self.cli_conn.recvfrom(1024)[0]
        self.assertEqual(data, MSG)

    def _testRecvFrom(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecvFrom(self):
        # Testing recvfrom() in chunks over TCP
        head = self.cli_conn.recvfrom(len(MSG) - 3)[0]
        tail = self.cli_conn.recvfrom(1024)[0]
        self.assertEqual(head + tail, MSG)

    def _testOverFlowRecvFrom(self):
        self.serv_conn.send(MSG)

    def testSendAll(self):
        # Testing sendall() with a 2048 byte string over TCP:
        # read until the peer closes, then compare the whole payload.
        chunks = []
        data = self.cli_conn.recv(1024)
        while data:
            chunks.append(data)
            data = self.cli_conn.recv(1024)
        self.assertEqual(''.join(chunks), 'f' * 2048)

    def _testSendAll(self):
        self.serv_conn.sendall('f' * 2048)

    def testFromFd(self):
        # Testing fromfd()
        fromfd = getattr(socket, "fromfd", None)
        if fromfd is None:
            return  # On Windows, this doesn't exist
        descriptor = self.cli_conn.fileno()
        sock = fromfd(descriptor, socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(sock.close)
        self.assertEqual(sock.recv(1024), MSG)

    def _testFromFd(self):
        self.serv_conn.send(MSG)

    def testDup(self):
        # Testing dup()
        duplicate = self.cli_conn.dup()
        self.addCleanup(duplicate.close)
        self.assertEqual(duplicate.recv(1024), MSG)

    def _testDup(self):
        self.serv_conn.send(MSG)

    def testShutdown(self):
        # Testing shutdown()
        self.assertEqual(self.cli_conn.recv(1024), MSG)
        # wait for _testShutdown to finish: on OS X, when the server
        # closes the connection the client also becomes disconnected,
        # and the client's shutdown call will fail. (Issue #4397.)
        self.done.wait()

    def _testShutdown(self):
        conn = self.serv_conn
        conn.send(MSG)
        # Issue 15989: out-of-range "how" arguments must raise OverflowError
        for how in (_testcapi.INT_MAX + 1, 2 + (_testcapi.UINT_MAX + 1)):
            self.assertRaises(OverflowError, conn.shutdown, how)
        conn.shutdown(2)
824
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicUDPTest(ThreadedUDPSocketTest):
    """Datagram round trips: the client thread sends MSG to the server's
    port and the main thread receives it on self.serv."""

    def __init__(self, methodName='runTest'):
        ThreadedUDPSocketTest.__init__(self, methodName=methodName)

    def testSendtoAndRecv(self):
        # Testing sendto() and Recv() over UDP
        self.assertEqual(self.serv.recv(len(MSG)), MSG)

    def _testSendtoAndRecv(self):
        target = (HOST, self.port)
        self.cli.sendto(MSG, 0, target)

    def testRecvFrom(self):
        # Testing recvfrom() over UDP
        data = self.serv.recvfrom(len(MSG))[0]
        self.assertEqual(data, MSG)

    def _testRecvFrom(self):
        target = (HOST, self.port)
        self.cli.sendto(MSG, 0, target)

    def testRecvFromNegative(self):
        # Negative lengths passed to recvfrom should give ValueError.
        self.assertRaises(ValueError, self.serv.recvfrom, -1)

    def _testRecvFromNegative(self):
        target = (HOST, self.port)
        self.cli.sendto(MSG, 0, target)
853
@unittest.skipUnless(thread, 'Threading required for this test.')
class TCPCloserTest(ThreadedTCPSocketTest):
    """When one side closes, the peer must see a readable EOF."""

    def testClose(self):
        accepted, _ = self.serv.accept()
        accepted.close()

        client_sock = self.cli
        readable = select.select([client_sock], [], [], 1.0)[0]
        self.assertEqual(readable, [client_sock])
        self.assertEqual(client_sock.recv(1), '')

    def _testClose(self):
        # Connect, then keep the client alive while the server closes.
        self.cli.connect((HOST, self.port))
        time.sleep(1.0)
869
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicSocketPairTest(SocketPairTest):
    """Send MSG across a socket pair in each direction."""

    def __init__(self, methodName='runTest'):
        SocketPairTest.__init__(self, methodName=methodName)

    # cli -> serv: the client half sends, the server half checks.
    def testRecv(self):
        self.assertEqual(self.serv.recv(1024), MSG)

    def _testRecv(self):
        self.cli.send(MSG)

    # serv -> cli: the server half sends, the client half checks.
    def testSend(self):
        self.serv.send(MSG)

    def _testSend(self):
        self.assertEqual(self.cli.recv(1024), MSG)
889
@unittest.skipUnless(thread, 'Threading required for this test.')
class NonBlockingTCPTests(ThreadedTCPSocketTest):
    """accept(), connect() and recv() on sockets in non-blocking mode.

    Accepted connections are closed on every path (explicitly or through
    addCleanup), so a failing assertion does not leak a socket.
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def testSetBlocking(self):
        # Testing whether set blocking works
        self.serv.setblocking(True)
        self.assertIsNone(self.serv.gettimeout())
        self.serv.setblocking(False)
        self.assertEqual(self.serv.gettimeout(), 0.0)
        start = time.time()
        try:
            self.serv.accept()
        except socket.error:
            pass
        end = time.time()
        self.assertLess(end - start, 1.0, "Error setting non-blocking mode.")
        # Issue 15989: a flag too large for a C unsigned int must still be
        # treated as true, i.e. blocking mode with no timeout.
        if _testcapi.UINT_MAX < _testcapi.ULONG_MAX:
            self.serv.setblocking(_testcapi.UINT_MAX + 1)
            self.assertIsNone(self.serv.gettimeout())

    def _testSetBlocking(self):
        pass

    def testAccept(self):
        # Testing non-blocking accept
        self.serv.setblocking(0)
        try:
            conn, addr = self.serv.accept()
        except socket.error:
            pass
        else:
            # Do not leak the unexpectedly accepted connection.
            conn.close()
            self.fail("Error trying to do non-blocking accept.")
        read, write, err = select.select([self.serv], [], [])
        if self.serv in read:
            conn, addr = self.serv.accept()
            conn.close()
        else:
            self.fail("Error trying to do accept after select.")

    def _testAccept(self):
        time.sleep(0.1)
        self.cli.connect((HOST, self.port))

    def testConnect(self):
        # Testing non-blocking connect
        conn, addr = self.serv.accept()
        conn.close()

    def _testConnect(self):
        self.cli.settimeout(10)
        self.cli.connect((HOST, self.port))

    def testRecv(self):
        # Testing non-blocking recv
        conn, addr = self.serv.accept()
        # Closed even if one of the assertions below fails.
        self.addCleanup(conn.close)
        conn.setblocking(0)
        try:
            conn.recv(len(MSG))
        except socket.error:
            pass
        else:
            self.fail("Error trying to do non-blocking recv.")
        read, write, err = select.select([conn], [], [])
        if conn in read:
            msg = conn.recv(len(MSG))
            self.assertEqual(msg, MSG)
        else:
            self.fail("Error during select call to non-blocking socket.")

    def _testRecv(self):
        self.cli.connect((HOST, self.port))
        time.sleep(0.1)
        self.cli.send(MSG)
968
@unittest.skipUnless(thread, 'Threading required for this test.')
class FileObjectClassTestCase(SocketConnectedTest):
    """Tests for the file objects returned by socket.makefile().

    ``serv_file`` is made from ``cli_conn`` with the class's ``bufsize``;
    ``cli_file`` is made from ``serv_conn``.  Subclasses repeat every test
    with a different ``bufsize``.
    """

    bufsize = -1 # Use default buffer size

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def setUp(self):
        SocketConnectedTest.setUp(self)
        self.serv_file = self.cli_conn.makefile('rb', self.bufsize)

    def tearDown(self):
        # try/finally: the base-class teardown must run even when closing
        # the file object or the closed-check fails.
        try:
            self.serv_file.close()
            self.assertTrue(self.serv_file.closed)
        finally:
            SocketConnectedTest.tearDown(self)
            self.serv_file = None

    def clientSetUp(self):
        SocketConnectedTest.clientSetUp(self)
        self.cli_file = self.serv_conn.makefile('wb')

    def clientTearDown(self):
        # Same reasoning as tearDown: always reach the base-class teardown.
        try:
            self.cli_file.close()
            self.assertTrue(self.cli_file.closed)
        finally:
            self.cli_file = None
            SocketConnectedTest.clientTearDown(self)

    def testSmallRead(self):
        # Performing small file read test
        first_seg = self.serv_file.read(len(MSG)-3)
        second_seg = self.serv_file.read(3)
        msg = first_seg + second_seg
        self.assertEqual(msg, MSG)

    def _testSmallRead(self):
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testFullRead(self):
        # read until EOF
        msg = self.serv_file.read()
        self.assertEqual(msg, MSG)

    def _testFullRead(self):
        self.cli_file.write(MSG)
        self.cli_file.close()

    def testUnbufferedRead(self):
        # Performing unbuffered file read test: read one byte at a time
        # until read() returns '' (EOF), then join the pieces once.
        buf = ''.join(iter(lambda: self.serv_file.read(1), ''))
        self.assertEqual(buf, MSG)

    def _testUnbufferedRead(self):
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testReadline(self):
        # Performing file readline test
        line = self.serv_file.readline()
        self.assertEqual(line, MSG)

    def _testReadline(self):
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testReadlineAfterRead(self):
        # readline() must continue from where the preceding read() calls
        # stopped, including their buffered data.
        a_baloo_is = self.serv_file.read(len("A baloo is"))
        self.assertEqual("A baloo is", a_baloo_is)
        _a_bear = self.serv_file.read(len(" a bear"))
        self.assertEqual(" a bear", _a_bear)
        line = self.serv_file.readline()
        self.assertEqual("\n", line)
        line = self.serv_file.readline()
        self.assertEqual("A BALOO IS A BEAR.\n", line)
        line = self.serv_file.readline()
        self.assertEqual(MSG, line)

    def _testReadlineAfterRead(self):
        self.cli_file.write("A baloo is a bear\n")
        self.cli_file.write("A BALOO IS A BEAR.\n")
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testReadlineAfterReadNoNewline(self):
        # A final line without a trailing newline is returned as-is.
        end_of_ = self.serv_file.read(len("End Of "))
        self.assertEqual("End Of ", end_of_)
        line = self.serv_file.readline()
        self.assertEqual("Line", line)

    def _testReadlineAfterReadNoNewline(self):
        self.cli_file.write("End Of Line")

    def testClosedAttr(self):
        self.assertFalse(self.serv_file.closed)

    def _testClosedAttr(self):
        self.assertFalse(self.cli_file.closed)
1072
1073
1074class FileObjectInterruptedTestCase(unittest.TestCase):
1075    """Test that the file object correctly handles EINTR internally."""
1076
1077    class MockSocket(object):
1078        def __init__(self, recv_funcs=()):
1079            # A generator that returns callables that we'll call for each
1080            # call to recv().
1081            self._recv_step = iter(recv_funcs)
1082
1083        def recv(self, size):
1084            return self._recv_step.next()()
1085
1086    @staticmethod
1087    def _raise_eintr():
1088        raise socket.error(errno.EINTR)
1089
1090    def _test_readline(self, size=-1, **kwargs):
1091        mock_sock = self.MockSocket(recv_funcs=[
1092                lambda : "This is the first line\nAnd the sec",
1093                self._raise_eintr,
1094                lambda : "ond line is here\n",
1095                lambda : "",
1096            ])
1097        fo = socket._fileobject(mock_sock, **kwargs)
1098        self.assertEqual(fo.readline(size), "This is the first line\n")
1099        self.assertEqual(fo.readline(size), "And the second line is here\n")
1100
1101    def _test_read(self, size=-1, **kwargs):
1102        mock_sock = self.MockSocket(recv_funcs=[
1103                lambda : "This is the first line\nAnd the sec",
1104                self._raise_eintr,
1105                lambda : "ond line is here\n",
1106                lambda : "",
1107            ])
1108        fo = socket._fileobject(mock_sock, **kwargs)
1109        self.assertEqual(fo.read(size), "This is the first line\n"
1110                          "And the second line is here\n")
1111
1112    def test_default(self):
1113        self._test_readline()
1114        self._test_readline(size=100)
1115        self._test_read()
1116        self._test_read(size=100)
1117
1118    def test_with_1k_buffer(self):
1119        self._test_readline(bufsize=1024)
1120        self._test_readline(size=100, bufsize=1024)
1121        self._test_read(bufsize=1024)
1122        self._test_read(size=100, bufsize=1024)
1123
1124    def _test_readline_no_buffer(self, size=-1):
1125        mock_sock = self.MockSocket(recv_funcs=[
1126                lambda : "aa",
1127                lambda : "\n",
1128                lambda : "BB",
1129                self._raise_eintr,
1130                lambda : "bb",
1131                lambda : "",
1132            ])
1133        fo = socket._fileobject(mock_sock, bufsize=0)
1134        self.assertEqual(fo.readline(size), "aa\n")
1135        self.assertEqual(fo.readline(size), "BBbb")
1136
1137    def test_no_buffer(self):
1138        self._test_readline_no_buffer()
1139        self._test_readline_no_buffer(size=4)
1140        self._test_read(bufsize=0)
1141        self._test_read(size=100, bufsize=0)
1142
1143
class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):

    """Re-run the FileObjectClassTestCase tests with bufsize==0.

    Only in unbuffered mode is it possible to read a line through one file
    object, create a second file object on the same socket and read the
    next line through it, without the first object having swallowed data
    into its buffer.  httplib depends on this when it reads several
    requests from one socket."""

    bufsize = 0 # Use unbuffered mode

    def testUnbufferedReadline(self):
        # First line through the original file object...
        first = self.serv_file.readline()
        self.assertEqual(first, "A. " + MSG)
        # ...second line through a brand-new unbuffered file object.
        self.serv_file = self.cli_conn.makefile('rb', 0)
        second = self.serv_file.readline()
        self.assertEqual(second, "B. " + MSG)

    def _testUnbufferedReadline(self):
        for prefix in ("A. ", "B. "):
            self.cli_file.write(prefix + MSG)
        self.cli_file.flush()
1168
class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):

    bufsize = 1 # Default-buffered for reading; line-buffered for writing

    class SocketMemo(object):
        """A wrapper to keep track of sent data, needed to examine write behaviour"""
        def __init__(self, sock):
            self._sock = sock
            self.sent = []

        def send(self, data, flags=0):
            count = self._sock.send(data, flags)
            # Record only the part the real socket accepted.
            self.sent.append(data[:count])
            return count

        def sendall(self, data, flags=0):
            self._sock.sendall(data, flags)
            self.sent.append(data)

        def __getattr__(self, attr):
            # Everything not overridden here goes to the wrapped socket.
            return getattr(self._sock, attr)

        def getsent(self):
            # Return the recorded chunks, with memoryviews turned into
            # byte strings so they compare equal to plain str values.
            chunks = []
            for chunk in self.sent:
                if isinstance(chunk, memoryview):
                    chunk = chunk.tobytes()
                chunks.append(chunk)
            return chunks

    def setUp(self):
        FileObjectClassTestCase.setUp(self)
        # Interpose a recorder between the file object and its socket.
        self.serv_file._sock = self.SocketMemo(self.serv_file._sock)

    def testLinebufferedWrite(self):
        # Build three lines out of several print statements; each finished
        # line must reach the socket as a single chunk.
        text = MSG.strip()
        print >> self.serv_file, text,
        print >> self.serv_file, text

        # second line:
        print >> self.serv_file, text,
        print >> self.serv_file, text,
        print >> self.serv_file, text

        # third line
        print >> self.serv_file, ''

        self.serv_file.flush()

        expected = ["%s %s\n" % (text, text),
                    "%s %s %s\n" % (text, text, text),
                    "\n"]
        self.assertEqual(self.serv_file._sock.getsent(), expected)

    def _testLinebufferedWrite(self):
        text = MSG.strip()
        expected_lines = ["%s %s\n" % (text, text),
                          "%s %s %s\n" % (text, text, text),
                          "\n"]
        for expected in expected_lines:
            received = self.cli_file.readline()
            self.assertEqual(received, expected)
1230
1231
class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Re-run the FileObjectClassTestCase tests with a 2-byte buffer."""

    bufsize = 2  # tiny buffer, to exercise the buffering code paths
1235
1236
class NetworkConnectionTest(object):
    """Mixin: the client side connects through socket.create_connection()."""
    def clientSetUp(self):
        # self.port is provided by the TCP test class this is mixed into
        # (BasicTCPTest2 combines it with BasicTCPTest).
        connection = socket.create_connection((HOST, self.port))
        self.cli = self.serv_conn = connection
1244
class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest):
    """Re-run BasicTCPTest with a client built by create_connection().

    Confirms that NetworkConnection does not break existing TCP
    functionality.
    """
1248
class NetworkConnectionNoServer(unittest.TestCase):
    """Connection attempts to a port on which nothing is listening."""

    class MockSocket(socket.socket):
        def connect(self, *args):
            # Every connect() attempt fails as a timeout.
            raise socket.timeout('timed out')

    @contextlib.contextmanager
    def mocked_socket_module(self):
        """Temporarily replace socket.socket with MockSocket."""
        saved_socket = socket.socket
        socket.socket = self.MockSocket
        try:
            yield
        finally:
            socket.socket = saved_socket

    def test_connect(self):
        unused_port = test_support.find_unused_port()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(sock.close)
        with self.assertRaises(socket.error) as ctx:
            sock.connect((HOST, unused_port))
        self.assertEqual(ctx.exception.errno, errno.ECONNREFUSED)

    def test_create_connection(self):
        # Issue #9792: errors raised by create_connection() should have
        # a proper errno attribute.
        unused_port = test_support.find_unused_port()
        with self.assertRaises(socket.error) as ctx:
            socket.create_connection((HOST, unused_port))

        # Issue #16257: create_connection() calls getaddrinfo() against
        # 'localhost'.  This may result in an IPV6 addr being returned
        # as well as an IPV4 one:
        #   >>> socket.getaddrinfo('localhost', port, 0, SOCK_STREAM)
        #   >>> [(2,  2, 0, '', ('127.0.0.1', 41230)),
        #        (26, 2, 0, '', ('::1', 41230, 0, 0))]
        #
        # create_connection() tries every returned address and, if none of
        # them connects, re-raises the last exception it saw.
        #
        # On Solaris that last error is ENETUNREACH rather than
        # ECONNREFUSED, so accept it too wherever errno defines it.
        if hasattr(errno, 'ENETUNREACH'):
            expected_errnos = (errno.ECONNREFUSED, errno.ENETUNREACH)
        else:
            expected_errnos = (errno.ECONNREFUSED,)

        self.assertIn(ctx.exception.errno, expected_errnos)

    def test_create_connection_timeout(self):
        # Issue #9792: create_connection() should not recast timeout errors
        # as generic socket errors.
        with self.mocked_socket_module():
            with self.assertRaises(socket.timeout):
                socket.create_connection((HOST, 1234))
1305
1306
@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
    """Attributes of the sockets returned by socket.create_connection()."""

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.source_port = test_support.find_unused_port()
        # Each client test assigns self.cli itself.  Start from None so that
        # clientTearDown does not raise AttributeError (hiding the real
        # failure) when create_connection() fails before the assignment.
        self.cli = None

    def clientTearDown(self):
        if self.cli is not None:
            self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def _justAccept(self):
        # Server half shared by every test: accept and drop the connection.
        conn, addr = self.serv.accept()
        conn.close()

    testFamily = _justAccept
    def _testFamily(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.family, socket.AF_INET)

    testSourceAddress = _justAccept
    def _testSourceAddress(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30,
                source_address=('', self.source_port))
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.getsockname()[1], self.source_port)
        # The port number being used is sufficient to show that the bind()
        # call happened.

    testTimeoutDefault = _justAccept
    def _testTimeoutDefault(self):
        # passing no explicit timeout uses socket's global default
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(42)
        try:
            self.cli = socket.create_connection((HOST, self.port))
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), 42)

    testTimeoutNone = _justAccept
    def _testTimeoutNone(self):
        # None timeout means the same as sock.settimeout(None)
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(30)
        try:
            self.cli = socket.create_connection((HOST, self.port), timeout=None)
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), None)

    testTimeoutValueNamed = _justAccept
    def _testTimeoutValueNamed(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)

    testTimeoutValueNonamed = _justAccept
    def _testTimeoutValueNonamed(self):
        self.cli = socket.create_connection((HOST, self.port), 30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)
1375
@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
    """Timeout behaviour of sockets returned by socket.create_connection()."""

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        # The client tests assign self.cli themselves; start from None so
        # clientTearDown does not raise AttributeError when
        # create_connection() fails before the assignment.
        self.cli = None

    def clientTearDown(self):
        if self.cli is not None:
            self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def testInsideTimeout(self):
        # Server half shared by both tests: accept, wait 3 s, then reply.
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        time.sleep(3)
        conn.send("done!")
    testOutsideTimeout = testInsideTimeout

    def _testInsideTimeout(self):
        # Without an explicit timeout, recv() waits for the delayed reply.
        self.cli = sock = socket.create_connection((HOST, self.port))
        data = sock.recv(5)
        self.assertEqual(data, "done!")

    def _testOutsideTimeout(self):
        # A 1 s timeout expires before the server's delayed reply arrives.
        self.cli = sock = socket.create_connection((HOST, self.port), timeout=1)
        self.assertRaises(socket.timeout, lambda: sock.recv(5))
1406
1407
1408class Urllib2FileobjectTest(unittest.TestCase):
1409
1410    # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
1411    # it close the socket if the close c'tor argument is true
1412
1413    def testClose(self):
1414        class MockSocket:
1415            closed = False
1416            def flush(self): pass
1417            def close(self): self.closed = True
1418
1419        # must not close unless we request it: the original use of _fileobject
1420        # by module socket requires that the underlying socket not be closed until
1421        # the _socketobject that created the _fileobject is closed
1422        s = MockSocket()
1423        f = socket._fileobject(s)
1424        f.close()
1425        self.assertTrue(not s.closed)
1426
1427        s = MockSocket()
1428        f = socket._fileobject(s, close=True)
1429        f.close()
1430        self.assertTrue(s.closed)
1431
class TCPTimeoutTest(SocketTCPTest):
    """Timeout behaviour of a listening TCP socket."""

    def testTCPTimeout(self):
        # accept() with a positive timeout and no pending connection must
        # raise socket.timeout.  (The old code passed its message string as
        # an argument to the callable, where it was silently ignored.)
        self.serv.settimeout(1.0)
        with self.assertRaises(socket.timeout):
            self.serv.accept()

    def testTimeoutZero(self):
        # A zero timeout must make accept() fail with a plain socket.error,
        # not socket.timeout.  settimeout() stays outside the try so an
        # error from it cannot be mistaken for the expected accept() error.
        # socket.timeout is a subclass of socket.error, so it is checked
        # first.
        self.serv.settimeout(0.0)
        try:
            self.serv.accept()
        except socket.timeout:
            self.fail("caught timeout instead of error (TCP)")
        except socket.error:
            pass
        else:
            self.fail("accept() returned success when we did not expect it")

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() is required for this test")
    def testInterruptedTimeout(self):
        # XXX I don't know how to do this test on MSWindows or any other
        # platform that doesn't support signal.alarm() or os.kill(), though
        # the bug should have existed on all platforms.
        self.serv.settimeout(5.0)   # must be longer than alarm
        class Alarm(Exception):
            pass
        def alarm_handler(signum, frame):
            raise Alarm
        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
        try:
            signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
            try:
                foo = self.serv.accept()
            except socket.timeout:
                self.fail("caught timeout instead of Alarm")
            except Alarm:
                pass
            except Exception:
                self.fail("caught other exception instead of Alarm:"
                          " %s(%s):\n%s" %
                          (sys.exc_info()[:2] + (traceback.format_exc(),)))
            else:
                self.fail("nothing caught")
            finally:
                signal.alarm(0)         # shut off alarm
        except Alarm:
            self.fail("got Alarm in wrong place")
        finally:
            # no alarm can be pending.  Safe to restore old handler.
            signal.signal(signal.SIGALRM, old_alarm)
1488
class UDPTimeoutTest(SocketUDPTest):
    """Timeout behaviour of a bound UDP socket."""

    def testUDPTimeout(self):
        # recv() with a positive timeout and no incoming datagram must raise
        # socket.timeout.  (The old code passed its message string as an
        # argument to the callable, where it was silently ignored.)
        self.serv.settimeout(1.0)
        with self.assertRaises(socket.timeout):
            self.serv.recv(1024)

    def testTimeoutZero(self):
        # A zero timeout must make recv() fail with a plain socket.error,
        # not socket.timeout.  settimeout() stays outside the try so an
        # error from it cannot be mistaken for the expected recv() error.
        # socket.timeout is a subclass of socket.error, so it is checked
        # first.
        self.serv.settimeout(0.0)
        try:
            self.serv.recv(1024)
        except socket.timeout:
            self.fail("caught timeout instead of error (UDP)")
        except socket.error:
            pass
        else:
            self.fail("recv() returned success when we did not expect it")
1511
class TestExceptions(unittest.TestCase):

    def testExceptionTree(self):
        # socket.error derives from Exception, and each more specific
        # socket exception derives from socket.error.
        self.assertTrue(issubclass(socket.error, Exception))
        for specific in (socket.herror, socket.gaierror, socket.timeout):
            self.assertTrue(issubclass(specific, socket.error))
1519
class TestLinuxAbstractNamespace(unittest.TestCase):
    """AF_UNIX addresses starting with a NUL byte (Linux abstract namespace).

    Every socket created here, including the accepted connection, is closed
    through addCleanup so no descriptor leaks, even when an assertion fails.
    """

    # presumably the size of sockaddr_un.sun_path on Linux -- confirm
    UNIX_PATH_MAX = 108

    def _unix_stream_socket(self):
        """Return a new AF_UNIX stream socket that is closed at cleanup."""
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.addCleanup(sock.close)
        return sock

    def testLinuxAbstractNamespace(self):
        address = "\x00python-test-hello\x00\xff"
        s1 = self._unix_stream_socket()
        s1.bind(address)
        s1.listen(1)
        s2 = self._unix_stream_socket()
        s2.connect(s1.getsockname())
        conn, _ = s1.accept()
        self.addCleanup(conn.close)
        self.assertEqual(s1.getsockname(), address)
        self.assertEqual(s2.getpeername(), address)

    def testMaxName(self):
        # The longest accepted name: NUL plus UNIX_PATH_MAX - 1 bytes.
        address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
        s = self._unix_stream_socket()
        s.bind(address)
        self.assertEqual(s.getsockname(), address)

    def testNameOverflow(self):
        # One byte more than testMaxName must be rejected by bind().
        address = "\x00" + "h" * self.UNIX_PATH_MAX
        s = self._unix_stream_socket()
        self.assertRaises(socket.error, s.bind, address)
1545
1546
@unittest.skipUnless(thread, 'Threading required for this test.')
class BufferIOTest(SocketConnectedTest):
    """
    Test the buffer versions of socket.recv() and socket.send().
    """
    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def _sendMsgAsBuffer(self):
        # Client half shared by every test below: send MSG wrapped in a
        # buffer object.  buffer() is created inside check_py3k_warnings(),
        # presumably to absorb the -3 warning it triggers.
        with test_support.check_py3k_warnings():
            payload = buffer(MSG)
        self.serv_conn.send(payload)

    def testRecvIntoArray(self):
        target = array.array('c', ' '*1024)
        count = self.cli_conn.recv_into(target)
        self.assertEqual(count, len(MSG))
        self.assertEqual(target.tostring()[:len(MSG)], MSG)

    _testRecvIntoArray = _sendMsgAsBuffer

    def testRecvIntoBytearray(self):
        target = bytearray(1024)
        count = self.cli_conn.recv_into(target)
        self.assertEqual(count, len(MSG))
        self.assertEqual(target[:len(MSG)], MSG)

    _testRecvIntoBytearray = _sendMsgAsBuffer

    def testRecvIntoMemoryview(self):
        target = bytearray(1024)
        count = self.cli_conn.recv_into(memoryview(target))
        self.assertEqual(count, len(MSG))
        self.assertEqual(target[:len(MSG)], MSG)

    _testRecvIntoMemoryview = _sendMsgAsBuffer

    def testRecvFromIntoArray(self):
        target = array.array('c', ' '*1024)
        count, _ = self.cli_conn.recvfrom_into(target)
        self.assertEqual(count, len(MSG))
        self.assertEqual(target.tostring()[:len(MSG)], MSG)

    _testRecvFromIntoArray = _sendMsgAsBuffer

    def testRecvFromIntoBytearray(self):
        target = bytearray(1024)
        count, _ = self.cli_conn.recvfrom_into(target)
        self.assertEqual(count, len(MSG))
        self.assertEqual(target[:len(MSG)], MSG)

    _testRecvFromIntoBytearray = _sendMsgAsBuffer

    def testRecvFromIntoMemoryview(self):
        target = bytearray(1024)
        count, _ = self.cli_conn.recvfrom_into(memoryview(target))
        self.assertEqual(count, len(MSG))
        self.assertEqual(target[:len(MSG)], MSG)

    _testRecvFromIntoMemoryview = _sendMsgAsBuffer
1614
1615
# TIPC name-sequence parameters shared by TIPCTest and TIPCThreadableTest:
# the server binds instances TIPC_LOWER..TIPC_UPPER of service type
# TIPC_STYPE, and clients address the instance midway through that range.
TIPC_STYPE, TIPC_LOWER, TIPC_UPPER = 2000, 200, 210
1619
1620def isTipcAvailable():
1621    """Check if the TIPC module is loaded
1622
1623    The TIPC module is not loaded automatically on Ubuntu and probably
1624    other Linux distros.
1625    """
1626    if not hasattr(socket, "AF_TIPC"):
1627        return False
1628    if not os.path.isfile("/proc/modules"):
1629        return False
1630    with open("/proc/modules") as f:
1631        for line in f:
1632            if line.startswith("tipc "):
1633                return True
1634    if test_support.verbose:
1635        print "TIPC module is not loaded, please 'sudo modprobe tipc'"
1636    return False
1637
class TIPCTest (unittest.TestCase):
    """Reliable-datagram (SOCK_RDM) round trip over TIPC."""

    def testRDM(self):
        # Both sockets are closed at cleanup, even if an assertion fails.
        srv = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(srv.close)
        cli = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(cli.close)

        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                TIPC_LOWER, TIPC_UPPER)
        srv.bind(srvaddr)

        # Address an instance in the middle of the bound range; // keeps
        # the instance an integer under true division as well.
        sendaddr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        cli.sendto(MSG, sendaddr)

        msg, recvaddr = srv.recvfrom(1024)

        self.assertEqual(cli.getsockname(), recvaddr)
        self.assertEqual(msg, MSG)
1656
1657
class TIPCThreadableTest (unittest.TestCase, ThreadableTest):
    """Stream-socket round trip over TIPC between server and client halves."""

    def __init__(self, methodName = 'runTest'):
        unittest.TestCase.__init__(self, methodName = methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        self.srv = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        self.addCleanup(self.srv.close)
        self.srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                TIPC_LOWER, TIPC_UPPER)
        self.srv.bind(srvaddr)
        self.srv.listen(5)
        self.serverExplicitReady()
        self.conn, self.connaddr = self.srv.accept()
        self.addCleanup(self.conn.close)

    def clientSetUp(self):
        # There is a hittable race between serverExplicitReady() and the
        # accept() call; sleep a little while to avoid it, otherwise
        # we could get an exception
        time.sleep(0.1)
        self.cli = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        self.addCleanup(self.cli.close)
        # Instance midway through the server's bound range (integer //).
        addr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        self.cli.connect(addr)
        self.cliaddr = self.cli.getsockname()

    def testStream(self):
        msg = self.conn.recv(1024)
        self.assertEqual(msg, MSG)
        self.assertEqual(self.cliaddr, self.connaddr)

    def _testStream(self):
        self.cli.send(MSG)
        self.cli.close()
1692
1693
def test_main():
    """Run the socket test classes that apply to the current platform.

    socketpair(), Linux abstract-namespace and TIPC tests are added only
    when supported.  The state returned by test_support.threading_setup()
    is handed back to threading_cleanup() even if the run raises.
    """
    tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
             TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest,
             UDPTimeoutTest ]

    tests.extend([
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        FileObjectInterruptedTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
        Urllib2FileobjectTest,
        NetworkConnectionNoServer,
        NetworkConnectionAttributesTest,
        NetworkConnectionBehaviourTest,
    ])
    if hasattr(socket, "socketpair"):
        tests.append(BasicSocketPairTest)
    # startswith: sys.platform is 'linux2' on older interpreters and
    # 'linux' on newer ones.
    if sys.platform.startswith('linux'):
        tests.append(TestLinuxAbstractNamespace)
    if isTipcAvailable():
        tests.append(TIPCTest)
        tests.append(TIPCThreadableTest)

    thread_info = test_support.threading_setup()
    try:
        test_support.run_unittest(*tests)
    finally:
        test_support.threading_cleanup(*thread_info)
1722
1723if __name__ == "__main__":
1724    test_main()
1725