1# Test the support for SSL and sockets
2
3import sys
4import unittest
5from test import support
6import socket
7import select
8import time
9import datetime
10import gc
11import os
12import errno
13import pprint
14import tempfile
15import urllib.request
16import traceback
17import asyncore
18import weakref
19import platform
20import functools
21
# Import ssl through the test-support helper rather than a plain import
# (presumably so the module is skipped when ssl is unavailable -- confirm
# against test.support.import_module).
ssl = support.import_module("ssl")

# threading is optional: record whether the import succeeded.
try:
    import threading
except ImportError:
    _have_threads = False
else:
    _have_threads = True

# Sorted protocol constants taken from the ssl module's private name table.
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST
# Which TLS library / version the ssl module reports being built against.
IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_OPENSSL_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
35
36
def data_file(*name):
    """Return the path built by joining *name* onto this module's directory."""
    here = os.path.dirname(__file__)
    return os.path.join(here, *name)
39
40# The custom key and certificate files used in test_ssl are generated
41# using Lib/test/make_ssl_certs.py.
42# Other certificates are simply fetched from the Internet servers they
43# are meant to authenticate.
44
# Paths to certificate / key data files, plus bytes-path variants produced
# with os.fsencode().
CERTFILE = data_file("keycert.pem")
BYTES_CERTFILE = os.fsencode(CERTFILE)
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
# Judging by the names these are password-protected variants and the
# matching password -- NOTE(review): confirm against make_ssl_certs.py.
CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
KEY_PASSWORD = "somepass"
# CA directory and two individual files inside it.
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0")


# empty CRL
CRLFILE = data_file("revocation.crl")

# Two keys and certs signed by the same CA (for SNI tests)
SIGNED_CERTFILE = data_file("keycert3.pem")
SIGNED_CERTFILE2 = data_file("keycert4.pem")
# Same certificate as pycacert.pem, but without extra text in file
SIGNING_CA = data_file("capath", "ceff1710.0")
# cert with all kinds of subject alt names
ALLSANFILE = data_file("allsans.pem")

# Host name of a remote server (not contacted anywhere in this part of the
# file).
REMOTE_HOST = "self-signed.pythontest.net"

# Data files used by the certificate-parsing and bad-certificate tests.
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem")

# Diffie-Hellman parameter file (str and bytes path).
DHFILE = data_file("dh1024.pem")
BYTES_DHFILE = os.fsencode(DHFILE)

# Not defined in all versions of OpenSSL
# (each name falls back to 0 when the ssl module lacks the attribute)
OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
88
89
def handle_error(prefix):
    """Write *prefix* followed by the current exception's traceback to stdout.

    The traceback is always formatted, but it is only written when
    ``support.verbose`` is true.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    formatted = ' '.join(
        traceback.format_exception(exc_type, exc_value, exc_tb))
    if not support.verbose:
        return
    sys.stdout.write(prefix + formatted)
94
def can_clear_options():
    """Return True when the OpenSSL API version is at least 0.9.8m."""
    minimum_api_version = (0, 9, 8, 13, 15)  # 0.9.8m
    return ssl._OPENSSL_API_VERSION >= minimum_api_version
98
def no_sslv2_implies_sslv3_hello():
    """Return True when the linked OpenSSL version is at least 0.9.7h."""
    minimum_version = (0, 9, 7, 8, 15)  # 0.9.7h
    return ssl.OPENSSL_VERSION_INFO >= minimum_version
102
def have_verify_flags():
    """Return True when the linked OpenSSL version is at least 0.9.8."""
    minimum_version = (0, 9, 8, 0, 15)  # 0.9.8
    return ssl.OPENSSL_VERSION_INFO >= minimum_version
106
def utc_offset(): #NOTE: ignore issues like #1647654
    """Return the local UTC offset in seconds.

    local time = utc time + utc offset.  The DST value (``time.altzone``)
    is used when the zone defines DST and it is currently in effect,
    otherwise ``time.timezone``.
    """
    dst_in_effect = time.daylight and time.localtime().tm_isdst > 0
    seconds_west = time.altzone if dst_in_effect else time.timezone
    return -seconds_west  # seconds
112
def asn1time(cert_time):
    """Return *cert_time* as the running OpenSSL is expected to print it.

    Some versions of OpenSSL ignore seconds, see #18207.  Only API version
    0.9.8i is adjusted: seconds are zeroed and a leading zero in the day is
    replaced by a space.  For every other version the string is returned
    unchanged.
    """
    # 0.9.8.i
    if ssl._OPENSSL_API_VERSION != (0, 9, 8, 9, 15):
        return cert_time

    fmt = "%b %d %H:%M:%S %Y GMT"
    parsed = datetime.datetime.strptime(cert_time, fmt)
    cert_time = parsed.replace(second=0).strftime(fmt)
    # %d adds leading zero but ASN1_TIME_print() uses leading space
    if cert_time[4] == "0":
        cert_time = cert_time[:4] + " " + cert_time[5:]
    return cert_time
126
# Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2
def skip_if_broken_ubuntu_ssl(func):
    """Decorator that skips *func* on the patched Ubuntu OpenSSL build.

    When the ssl module has no ``PROTOCOL_SSLv2``, *func* is returned
    unchanged.  Otherwise the wrapper tries to create an SSLv2 context; if
    that raises ``ssl.SSLError`` on the specific version/distribution pair
    below, ``unittest.SkipTest`` is raised.  In every other case *func* is
    called normally.
    """
    if not hasattr(ssl, 'PROTOCOL_SSLv2'):
        return func

    broken_version = (0, 9, 8, 15, 15)
    broken_distro = ('debian', 'squeeze/sid', '')

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            ssl.SSLContext(ssl.PROTOCOL_SSLv2)
        except ssl.SSLError:
            version_matches = ssl.OPENSSL_VERSION_INFO == broken_version
            if (version_matches and
                    platform.linux_distribution() == broken_distro):
                raise unittest.SkipTest("Patched Ubuntu OpenSSL breaks behaviour")
        return func(*args, **kwargs)

    return wrapper
142
# Decorator: skip the decorated test unless ssl.HAS_SNI is true.
needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test")
144
145
def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *,
                     cert_reqs=ssl.CERT_NONE, ca_certs=None,
                     ciphers=None, certfile=None, keyfile=None,
                     **kwargs):
    """Wrap *sock* using a freshly configured ``ssl.SSLContext``.

    A context is created for *ssl_version*; each optional argument that is
    not None is applied to it (verify mode, CA locations, certificate
    chain, cipher string).  Remaining keyword arguments are forwarded to
    ``SSLContext.wrap_socket()``, whose result is returned.
    """
    ctx = ssl.SSLContext(ssl_version)
    if cert_reqs is not None:
        ctx.verify_mode = cert_reqs
    if ca_certs is not None:
        ctx.load_verify_locations(ca_certs)
    # Load a chain as soon as either half of the cert/key pair is given.
    if not (certfile is None and keyfile is None):
        ctx.load_cert_chain(certfile, keyfile)
    if ciphers is not None:
        ctx.set_ciphers(ciphers)
    wrapped = ctx.wrap_socket(sock, **kwargs)
    return wrapped
160
161class BasicSocketTests(unittest.TestCase):
162
163    def test_constants(self):
164        ssl.CERT_NONE
165        ssl.CERT_OPTIONAL
166        ssl.CERT_REQUIRED
167        ssl.OP_CIPHER_SERVER_PREFERENCE
168        ssl.OP_SINGLE_DH_USE
169        if ssl.HAS_ECDH:
170            ssl.OP_SINGLE_ECDH_USE
171        if ssl.OPENSSL_VERSION_INFO >= (1, 0):
172            ssl.OP_NO_COMPRESSION
173        self.assertIn(ssl.HAS_SNI, {True, False})
174        self.assertIn(ssl.HAS_ECDH, {True, False})
175
176    def test_str_for_enums(self):
177        # Make sure that the PROTOCOL_* constants have enum-like string
178        # reprs.
179        proto = ssl.PROTOCOL_TLS
180        self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_TLS')
181        ctx = ssl.SSLContext(proto)
182        self.assertIs(ctx.protocol, proto)
183
184    def test_random(self):
185        v = ssl.RAND_status()
186        if support.verbose:
187            sys.stdout.write("\n RAND_status is %d (%s)\n"
188                             % (v, (v and "sufficient randomness") or
189                                "insufficient randomness"))
190
191        data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
192        self.assertEqual(len(data), 16)
193        self.assertEqual(is_cryptographic, v == 1)
194        if v:
195            data = ssl.RAND_bytes(16)
196            self.assertEqual(len(data), 16)
197        else:
198            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
199
200        # negative num is invalid
201        self.assertRaises(ValueError, ssl.RAND_bytes, -5)
202        self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5)
203
204        if hasattr(ssl, 'RAND_egd'):
205            self.assertRaises(TypeError, ssl.RAND_egd, 1)
206            self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
207        ssl.RAND_add("this is a random string", 75.0)
208        ssl.RAND_add(b"this is a random bytes object", 75.0)
209        ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0)
210
    @unittest.skipUnless(os.name == 'posix', 'requires posix')
    def test_random_fork(self):
        # Fork, let parent and child each draw 16 pseudo-random bytes, and
        # check that the two draws differ.  The child sends its bytes to the
        # parent over a pipe.
        status = ssl.RAND_status()
        if not status:
            self.fail("OpenSSL's PRNG has insufficient randomness")

        rfd, wfd = os.pipe()
        pid = os.fork()
        if pid == 0:
            # Child: write the random bytes to the pipe, then leave through
            # os._exit() with 0 on success and 1 on any exception.
            try:
                os.close(rfd)
                child_random = ssl.RAND_pseudo_bytes(16)[0]
                self.assertEqual(len(child_random), 16)
                os.write(wfd, child_random)
                os.close(wfd)
            except BaseException:
                os._exit(1)
            else:
                os._exit(0)
        else:
            # Parent: close the unused write end, wait for the child and
            # require a zero wait status before reading its bytes.
            os.close(wfd)
            self.addCleanup(os.close, rfd)
            _, status = os.waitpid(pid, 0)
            self.assertEqual(status, 0)

            child_random = os.read(rfd, 16)
            self.assertEqual(len(child_random), 16)
            parent_random = ssl.RAND_pseudo_bytes(16)[0]
            self.assertEqual(len(parent_random), 16)

            self.assertNotEqual(child_random, parent_random)
242
243    def test_parse_cert(self):
244        # note that this uses an 'unofficial' function in _ssl.c,
245        # provided solely for this test, to exercise the certificate
246        # parsing code
247        p = ssl._ssl._test_decode_cert(CERTFILE)
248        if support.verbose:
249            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
250        self.assertEqual(p['issuer'],
251                         ((('countryName', 'XY'),),
252                          (('localityName', 'Castle Anthrax'),),
253                          (('organizationName', 'Python Software Foundation'),),
254                          (('commonName', 'localhost'),))
255                        )
256        # Note the next three asserts will fail if the keys are regenerated
257        self.assertEqual(p['notAfter'], asn1time('Oct  5 23:01:56 2020 GMT'))
258        self.assertEqual(p['notBefore'], asn1time('Oct  8 23:01:56 2010 GMT'))
259        self.assertEqual(p['serialNumber'], 'D7C7381919AFC24E')
260        self.assertEqual(p['subject'],
261                         ((('countryName', 'XY'),),
262                          (('localityName', 'Castle Anthrax'),),
263                          (('organizationName', 'Python Software Foundation'),),
264                          (('commonName', 'localhost'),))
265                        )
266        self.assertEqual(p['subjectAltName'], (('DNS', 'localhost'),))
267        # Issue #13034: the subjectAltName in some certificates
268        # (notably projects.developer.nokia.com:443) wasn't parsed
269        p = ssl._ssl._test_decode_cert(NOKIACERT)
270        if support.verbose:
271            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
272        self.assertEqual(p['subjectAltName'],
273                         (('DNS', 'projects.developer.nokia.com'),
274                          ('DNS', 'projects.forum.nokia.com'))
275                        )
276        # extra OCSP and AIA fields
277        self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',))
278        self.assertEqual(p['caIssuers'],
279                         ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',))
280        self.assertEqual(p['crlDistributionPoints'],
281                         ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
282
283    def test_parse_cert_CVE_2013_4238(self):
284        p = ssl._ssl._test_decode_cert(NULLBYTECERT)
285        if support.verbose:
286            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
287        subject = ((('countryName', 'US'),),
288                   (('stateOrProvinceName', 'Oregon'),),
289                   (('localityName', 'Beaverton'),),
290                   (('organizationName', 'Python Software Foundation'),),
291                   (('organizationalUnitName', 'Python Core Development'),),
292                   (('commonName', 'null.python.org\x00example.org'),),
293                   (('emailAddress', 'python-dev@python.org'),))
294        self.assertEqual(p['subject'], subject)
295        self.assertEqual(p['issuer'], subject)
296        if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
297            san = (('DNS', 'altnull.python.org\x00example.com'),
298                   ('email', 'null@python.org\x00user@example.org'),
299                   ('URI', 'http://null.python.org\x00http://example.org'),
300                   ('IP Address', '192.0.2.1'),
301                   ('IP Address', '2001:DB8:0:0:0:0:0:1\n'))
302        else:
303            # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
304            san = (('DNS', 'altnull.python.org\x00example.com'),
305                   ('email', 'null@python.org\x00user@example.org'),
306                   ('URI', 'http://null.python.org\x00http://example.org'),
307                   ('IP Address', '192.0.2.1'),
308                   ('IP Address', '<invalid>'))
309
310        self.assertEqual(p['subjectAltName'], san)
311
312    def test_parse_all_sans(self):
313        p = ssl._ssl._test_decode_cert(ALLSANFILE)
314        self.assertEqual(p['subjectAltName'],
315            (
316                ('DNS', 'allsans'),
317                ('othername', '<unsupported>'),
318                ('othername', '<unsupported>'),
319                ('email', 'user@example.org'),
320                ('DNS', 'www.example.org'),
321                ('DirName',
322                    ((('countryName', 'XY'),),
323                    (('localityName', 'Castle Anthrax'),),
324                    (('organizationName', 'Python Software Foundation'),),
325                    (('commonName', 'dirname example'),))),
326                ('URI', 'https://www.python.org/'),
327                ('IP Address', '127.0.0.1'),
328                ('IP Address', '0:0:0:0:0:0:0:1\n'),
329                ('Registered ID', '1.2.3.4.5')
330            )
331        )
332
333    def test_DER_to_PEM(self):
334        with open(CAFILE_CACERT, 'r') as f:
335            pem = f.read()
336        d1 = ssl.PEM_cert_to_DER_cert(pem)
337        p2 = ssl.DER_cert_to_PEM_cert(d1)
338        d2 = ssl.PEM_cert_to_DER_cert(p2)
339        self.assertEqual(d1, d2)
340        if not p2.startswith(ssl.PEM_HEADER + '\n'):
341            self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
342        if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
343            self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
344
345    def test_openssl_version(self):
346        n = ssl.OPENSSL_VERSION_NUMBER
347        t = ssl.OPENSSL_VERSION_INFO
348        s = ssl.OPENSSL_VERSION
349        self.assertIsInstance(n, int)
350        self.assertIsInstance(t, tuple)
351        self.assertIsInstance(s, str)
352        # Some sanity checks follow
353        # >= 0.9
354        self.assertGreaterEqual(n, 0x900000)
355        # < 3.0
356        self.assertLess(n, 0x30000000)
357        major, minor, fix, patch, status = t
358        self.assertGreaterEqual(major, 0)
359        self.assertLess(major, 3)
360        self.assertGreaterEqual(minor, 0)
361        self.assertLess(minor, 256)
362        self.assertGreaterEqual(fix, 0)
363        self.assertLess(fix, 256)
364        self.assertGreaterEqual(patch, 0)
365        self.assertLessEqual(patch, 63)
366        self.assertGreaterEqual(status, 0)
367        self.assertLessEqual(status, 15)
368        # Version string as returned by {Open,Libre}SSL, the format might change
369        if IS_LIBRESSL:
370            self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
371                            (s, t, hex(n)))
372        else:
373            self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
374                            (s, t, hex(n)))
375
376    @support.cpython_only
377    def test_refcycle(self):
378        # Issue #7943: an SSL object doesn't create reference cycles with
379        # itself.
380        s = socket.socket(socket.AF_INET)
381        ss = test_wrap_socket(s)
382        wr = weakref.ref(ss)
383        with support.check_warnings(("", ResourceWarning)):
384            del ss
385        self.assertEqual(wr(), None)
386
387    def test_wrapped_unconnected(self):
388        # Methods on an unconnected SSLSocket propagate the original
389        # OSError raise by the underlying socket object.
390        s = socket.socket(socket.AF_INET)
391        with test_wrap_socket(s) as ss:
392            self.assertRaises(OSError, ss.recv, 1)
393            self.assertRaises(OSError, ss.recv_into, bytearray(b'x'))
394            self.assertRaises(OSError, ss.recvfrom, 1)
395            self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
396            self.assertRaises(OSError, ss.send, b'x')
397            self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
398
399    def test_timeout(self):
400        # Issue #8524: when creating an SSL socket, the timeout of the
401        # original socket should be retained.
402        for timeout in (None, 0.0, 5.0):
403            s = socket.socket(socket.AF_INET)
404            s.settimeout(timeout)
405            with test_wrap_socket(s) as ss:
406                self.assertEqual(timeout, ss.gettimeout())
407
408    def test_errors_sslwrap(self):
409        sock = socket.socket()
410        self.assertRaisesRegex(ValueError,
411                        "certfile must be specified",
412                        ssl.wrap_socket, sock, keyfile=CERTFILE)
413        self.assertRaisesRegex(ValueError,
414                        "certfile must be specified for server-side operations",
415                        ssl.wrap_socket, sock, server_side=True)
416        self.assertRaisesRegex(ValueError,
417                        "certfile must be specified for server-side operations",
418                         ssl.wrap_socket, sock, server_side=True, certfile="")
419        with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
420            self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
421                                     s.connect, (HOST, 8080))
422        with self.assertRaises(OSError) as cm:
423            with socket.socket() as sock:
424                ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
425        self.assertEqual(cm.exception.errno, errno.ENOENT)
426        with self.assertRaises(OSError) as cm:
427            with socket.socket() as sock:
428                ssl.wrap_socket(sock,
429                    certfile=CERTFILE, keyfile=NONEXISTINGCERT)
430        self.assertEqual(cm.exception.errno, errno.ENOENT)
431        with self.assertRaises(OSError) as cm:
432            with socket.socket() as sock:
433                ssl.wrap_socket(sock,
434                    certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
435        self.assertEqual(cm.exception.errno, errno.ENOENT)
436
437    def bad_cert_test(self, certfile):
438        """Check that trying to use the given client certificate fails"""
439        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
440                                   certfile)
441        sock = socket.socket()
442        self.addCleanup(sock.close)
443        with self.assertRaises(ssl.SSLError):
444            test_wrap_socket(sock,
445                            certfile=certfile,
446                            ssl_version=ssl.PROTOCOL_TLSv1)
447
448    def test_empty_cert(self):
449        """Wrapping with an empty cert file"""
450        self.bad_cert_test("nullcert.pem")
451
452    def test_malformed_cert(self):
453        """Wrapping with a badly formatted certificate (syntax error)"""
454        self.bad_cert_test("badcert.pem")
455
456    def test_malformed_key(self):
457        """Wrapping with a badly formatted key (syntax error)"""
458        self.bad_cert_test("badkey.pem")
459
    def test_match_hostname(self):
        # Table-style test of ssl.match_hostname(): `cert` is rebound to a
        # new certificate dict for each group and then checked against
        # host names that must match (ok) or must be rejected (fail).
        def ok(cert, hostname):
            # Passes as long as match_hostname() does not raise.
            ssl.match_hostname(cert, hostname)
        def fail(cert, hostname):
            # match_hostname() must raise CertificateError.
            self.assertRaises(ssl.CertificateError,
                              ssl.match_hostname, cert, hostname)

        # -- Hostname matching --

        cert = {'subject': ((('commonName', 'example.com'),),)}
        ok(cert, 'example.com')
        ok(cert, 'ExAmple.cOm')
        fail(cert, 'www.example.com')
        fail(cert, '.example.com')
        fail(cert, 'example.org')
        fail(cert, 'exampleXcom')

        cert = {'subject': ((('commonName', '*.a.com'),),)}
        ok(cert, 'foo.a.com')
        fail(cert, 'bar.foo.a.com')
        fail(cert, 'a.com')
        fail(cert, 'Xa.com')
        fail(cert, '.a.com')

        # only match one left-most wildcard
        cert = {'subject': ((('commonName', 'f*.com'),),)}
        ok(cert, 'foo.com')
        ok(cert, 'f.com')
        fail(cert, 'bar.com')
        fail(cert, 'foo.a.com')
        fail(cert, 'bar.foo.com')

        # NULL bytes are bad, CVE-2013-4073
        cert = {'subject': ((('commonName',
                              'null.python.org\x00example.org'),),)}
        ok(cert, 'null.python.org\x00example.org') # or raise an error?
        fail(cert, 'example.org')
        fail(cert, 'null.python.org')

        # error cases with wildcards
        cert = {'subject': ((('commonName', '*.*.a.com'),),)}
        fail(cert, 'bar.foo.a.com')
        fail(cert, 'a.com')
        fail(cert, 'Xa.com')
        fail(cert, '.a.com')

        cert = {'subject': ((('commonName', 'a.*.com'),),)}
        fail(cert, 'a.foo.com')
        fail(cert, 'a..com')
        fail(cert, 'a.com')

        # wildcard doesn't match IDNA prefix 'xn--'
        idna = 'püthon.python.org'.encode("idna").decode("ascii")
        cert = {'subject': ((('commonName', idna),),)}
        ok(cert, idna)
        cert = {'subject': ((('commonName', 'x*.python.org'),),)}
        fail(cert, idna)
        cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
        fail(cert, idna)

        # a wildcard in the first fragment combined with IDNA A-labels in
        # subsequent fragments is supported.
        idna = 'www*.pythön.org'.encode("idna").decode("ascii")
        cert = {'subject': ((('commonName', idna),),)}
        ok(cert, 'www.pythön.org'.encode("idna").decode("ascii"))
        ok(cert, 'www1.pythön.org'.encode("idna").decode("ascii"))
        fail(cert, 'ftp.pythön.org'.encode("idna").decode("ascii"))
        fail(cert, 'pythön.org'.encode("idna").decode("ascii"))

        # Slightly fake real-world example
        cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
                'subject': ((('commonName', 'linuxfrz.org'),),),
                'subjectAltName': (('DNS', 'linuxfr.org'),
                                   ('DNS', 'linuxfr.com'),
                                   ('othername', '<unsupported>'))}
        ok(cert, 'linuxfr.org')
        ok(cert, 'linuxfr.com')
        # Not a "DNS" entry
        fail(cert, '<unsupported>')
        # When there is a subjectAltName, commonName isn't used
        fail(cert, 'linuxfrz.org')

        # A pristine real-world example
        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
                'subject': ((('countryName', 'US'),),
                            (('stateOrProvinceName', 'California'),),
                            (('localityName', 'Mountain View'),),
                            (('organizationName', 'Google Inc'),),
                            (('commonName', 'mail.google.com'),))}
        ok(cert, 'mail.google.com')
        fail(cert, 'gmail.com')
        # Only commonName is considered
        fail(cert, 'California')

        # -- IPv4 matching --
        cert = {'subject': ((('commonName', 'example.com'),),),
                'subjectAltName': (('DNS', 'example.com'),
                                   ('IP Address', '10.11.12.13'),
                                   ('IP Address', '14.15.16.17'))}
        ok(cert, '10.11.12.13')
        ok(cert, '14.15.16.17')
        fail(cert, '14.15.16.18')
        fail(cert, 'example.net')

        # -- IPv6 matching --
        # The certificate values are uppercase with a trailing newline; the
        # compressed lowercase forms below are still expected to match.
        cert = {'subject': ((('commonName', 'example.com'),),),
                'subjectAltName': (('DNS', 'example.com'),
                                   ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
                                   ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
        ok(cert, '2001::cafe')
        ok(cert, '2003::baba')
        fail(cert, '2003::bebe')
        fail(cert, 'example.net')

        # -- Miscellaneous --

        # Neither commonName nor subjectAltName
        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
                'subject': ((('countryName', 'US'),),
                            (('stateOrProvinceName', 'California'),),
                            (('localityName', 'Mountain View'),),
                            (('organizationName', 'Google Inc'),))}
        fail(cert, 'mail.google.com')

        # No DNS entry in subjectAltName but a commonName
        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
                'subject': ((('countryName', 'US'),),
                            (('stateOrProvinceName', 'California'),),
                            (('localityName', 'Mountain View'),),
                            (('commonName', 'mail.google.com'),)),
                'subjectAltName': (('othername', 'blabla'), )}
        ok(cert, 'mail.google.com')

        # No DNS entry subjectAltName and no commonName
        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
                'subject': ((('countryName', 'US'),),
                            (('stateOrProvinceName', 'California'),),
                            (('localityName', 'Mountain View'),),
                            (('organizationName', 'Google Inc'),)),
                'subjectAltName': (('othername', 'blabla'),)}
        fail(cert, 'google.com')

        # Empty cert / no cert
        self.assertRaises(ValueError, ssl.match_hostname, None, 'example.com')
        self.assertRaises(ValueError, ssl.match_hostname, {}, 'example.com')

        # Issue #17980: avoid denials of service by refusing more than one
        # wildcard per fragment.
        cert = {'subject': ((('commonName', 'a*b.com'),),)}
        ok(cert, 'axxb.com')
        cert = {'subject': ((('commonName', 'a*b.co*'),),)}
        fail(cert, 'axxb.com')
        cert = {'subject': ((('commonName', 'a*b*.com'),),)}
        with self.assertRaises(ssl.CertificateError) as cm:
            ssl.match_hostname(cert, 'axxbxxc.com')
        self.assertIn("too many wildcards", str(cm.exception))
616
617    def test_server_side(self):
618        # server_hostname doesn't work for server sockets
619        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
620        with socket.socket() as sock:
621            self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
622                              server_hostname="some.hostname")
623
624    def test_unknown_channel_binding(self):
625        # should raise ValueError for unknown type
626        s = socket.socket(socket.AF_INET)
627        s.bind(('127.0.0.1', 0))
628        s.listen()
629        c = socket.socket(socket.AF_INET)
630        c.connect(s.getsockname())
631        with test_wrap_socket(c, do_handshake_on_connect=False) as ss:
632            with self.assertRaises(ValueError):
633                ss.get_channel_binding("unknown-type")
634        s.close()
635
636    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
637                         "'tls-unique' channel binding not available")
638    def test_tls_unique_channel_binding(self):
639        # unconnected should return None for known type
640        s = socket.socket(socket.AF_INET)
641        with test_wrap_socket(s) as ss:
642            self.assertIsNone(ss.get_channel_binding("tls-unique"))
643        # the same for server-side
644        s = socket.socket(socket.AF_INET)
645        with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
646            self.assertIsNone(ss.get_channel_binding("tls-unique"))
647
648    def test_dealloc_warn(self):
649        ss = test_wrap_socket(socket.socket(socket.AF_INET))
650        r = repr(ss)
651        with self.assertWarns(ResourceWarning) as cm:
652            ss = None
653            support.gc_collect()
654        self.assertIn(r, str(cm.warning.args[0]))
655
656    def test_get_default_verify_paths(self):
657        paths = ssl.get_default_verify_paths()
658        self.assertEqual(len(paths), 6)
659        self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
660
661        with support.EnvironmentVarGuard() as env:
662            env["SSL_CERT_DIR"] = CAPATH
663            env["SSL_CERT_FILE"] = CERTFILE
664            paths = ssl.get_default_verify_paths()
665            self.assertEqual(paths.cafile, CERTFILE)
666            self.assertEqual(paths.capath, CAPATH)
667
668    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
669    def test_enum_certificates(self):
670        self.assertTrue(ssl.enum_certificates("CA"))
671        self.assertTrue(ssl.enum_certificates("ROOT"))
672
673        self.assertRaises(TypeError, ssl.enum_certificates)
674        self.assertRaises(WindowsError, ssl.enum_certificates, "")
675
676        trust_oids = set()
677        for storename in ("CA", "ROOT"):
678            store = ssl.enum_certificates(storename)
679            self.assertIsInstance(store, list)
680            for element in store:
681                self.assertIsInstance(element, tuple)
682                self.assertEqual(len(element), 3)
683                cert, enc, trust = element
684                self.assertIsInstance(cert, bytes)
685                self.assertIn(enc, {"x509_asn", "pkcs_7_asn"})
686                self.assertIsInstance(trust, (set, bool))
687                if isinstance(trust, set):
688                    trust_oids.update(trust)
689
690        serverAuth = "1.3.6.1.5.5.7.3.1"
691        self.assertIn(serverAuth, trust_oids)
692
693    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
694    def test_enum_crls(self):
695        self.assertTrue(ssl.enum_crls("CA"))
696        self.assertRaises(TypeError, ssl.enum_crls)
697        self.assertRaises(WindowsError, ssl.enum_crls, "")
698
699        crls = ssl.enum_crls("CA")
700        self.assertIsInstance(crls, list)
701        for element in crls:
702            self.assertIsInstance(element, tuple)
703            self.assertEqual(len(element), 2)
704            self.assertIsInstance(element[0], bytes)
705            self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"})
706
707
708    def test_asn1object(self):
709        expected = (129, 'serverAuth', 'TLS Web Server Authentication',
710                    '1.3.6.1.5.5.7.3.1')
711
712        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
713        self.assertEqual(val, expected)
714        self.assertEqual(val.nid, 129)
715        self.assertEqual(val.shortname, 'serverAuth')
716        self.assertEqual(val.longname, 'TLS Web Server Authentication')
717        self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1')
718        self.assertIsInstance(val, ssl._ASN1Object)
719        self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth')
720
721        val = ssl._ASN1Object.fromnid(129)
722        self.assertEqual(val, expected)
723        self.assertIsInstance(val, ssl._ASN1Object)
724        self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1)
725        with self.assertRaisesRegex(ValueError, "unknown NID 100000"):
726            ssl._ASN1Object.fromnid(100000)
727        for i in range(1000):
728            try:
729                obj = ssl._ASN1Object.fromnid(i)
730            except ValueError:
731                pass
732            else:
733                self.assertIsInstance(obj.nid, int)
734                self.assertIsInstance(obj.shortname, str)
735                self.assertIsInstance(obj.longname, str)
736                self.assertIsInstance(obj.oid, (str, type(None)))
737
738        val = ssl._ASN1Object.fromname('TLS Web Server Authentication')
739        self.assertEqual(val, expected)
740        self.assertIsInstance(val, ssl._ASN1Object)
741        self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected)
742        self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'),
743                         expected)
744        with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"):
745            ssl._ASN1Object.fromname('serverauth')
746
747    def test_purpose_enum(self):
748        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
749        self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object)
750        self.assertEqual(ssl.Purpose.SERVER_AUTH, val)
751        self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129)
752        self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth')
753        self.assertEqual(ssl.Purpose.SERVER_AUTH.oid,
754                              '1.3.6.1.5.5.7.3.1')
755
756        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2')
757        self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object)
758        self.assertEqual(ssl.Purpose.CLIENT_AUTH, val)
759        self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130)
760        self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth')
761        self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid,
762                              '1.3.6.1.5.5.7.3.2')
763
764    def test_unsupported_dtls(self):
765        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
766        self.addCleanup(s.close)
767        with self.assertRaises(NotImplementedError) as cx:
768            test_wrap_socket(s, cert_reqs=ssl.CERT_NONE)
769        self.assertEqual(str(cx.exception), "only stream sockets are supported")
770        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
771        with self.assertRaises(NotImplementedError) as cx:
772            ctx.wrap_socket(s)
773        self.assertEqual(str(cx.exception), "only stream sockets are supported")
774
775    def cert_time_ok(self, timestring, timestamp):
776        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
777
778    def cert_time_fail(self, timestring):
779        with self.assertRaises(ValueError):
780            ssl.cert_time_to_seconds(timestring)
781
782    @unittest.skipUnless(utc_offset(),
783                         'local time needs to be different from UTC')
784    def test_cert_time_to_seconds_timezone(self):
785        # Issue #19940: ssl.cert_time_to_seconds() returns wrong
786        #               results if local timezone is not UTC
787        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0)
788        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0)
789
790    def test_cert_time_to_seconds(self):
791        timestring = "Jan  5 09:34:43 2018 GMT"
792        ts = 1515144883.0
793        self.cert_time_ok(timestring, ts)
794        # accept keyword parameter, assert its name
795        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
796        # accept both %e and %d (space or zero generated by strftime)
797        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
798        # case-insensitive
799        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts)
800        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds
801        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT
802        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone
803        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day
804        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month
805        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour
806        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute
807
808        newyear_ts = 1230768000.0
809        # leap seconds
810        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
811        # same timestamp
812        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts)
813
814        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899)
815        #  allow 60th second (even if it is not a leap second)
816        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900)
817        #  allow 2nd leap second for compatibility with time.strptime()
818        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901)
819        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds
820
821        # no special treatement for the special value:
822        #   99991231235959Z (rfc 5280)
823        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
824
825    @support.run_with_locale('LC_ALL', '')
826    def test_cert_time_to_seconds_locale(self):
827        # `cert_time_to_seconds()` should be locale independent
828
829        def local_february_name():
830            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
831
832        if local_february_name().lower() == 'feb':
833            self.skipTest("locale-specific month name needs to be "
834                          "different from C locale")
835
836        # locale-independent
837        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0)
838        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT")
839
840    def test_connect_ex_error(self):
841        server = socket.socket(socket.AF_INET)
842        self.addCleanup(server.close)
843        port = support.bind_port(server)  # Reserve port but don't listen
844        s = test_wrap_socket(socket.socket(socket.AF_INET),
845                            cert_reqs=ssl.CERT_REQUIRED)
846        self.addCleanup(s.close)
847        rc = s.connect_ex((HOST, port))
848        # Issue #19919: Windows machines or VMs hosted on Windows
849        # machines sometimes return EWOULDBLOCK.
850        errors = (
851            errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.ETIMEDOUT,
852            errno.EWOULDBLOCK,
853        )
854        self.assertIn(rc, errors)
855
856
857class ContextTests(unittest.TestCase):
858
859    @skip_if_broken_ubuntu_ssl
860    def test_constructor(self):
861        for protocol in PROTOCOLS:
862            ssl.SSLContext(protocol)
863        ctx = ssl.SSLContext()
864        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
865        self.assertRaises(ValueError, ssl.SSLContext, -1)
866        self.assertRaises(ValueError, ssl.SSLContext, 42)
867
868    @skip_if_broken_ubuntu_ssl
869    def test_protocol(self):
870        for proto in PROTOCOLS:
871            ctx = ssl.SSLContext(proto)
872            self.assertEqual(ctx.protocol, proto)
873
874    def test_ciphers(self):
875        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
876        ctx.set_ciphers("ALL")
877        ctx.set_ciphers("DEFAULT")
878        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
879            ctx.set_ciphers("^$:,;?*'dorothyx")
880
881    @unittest.skipIf(ssl.OPENSSL_VERSION_INFO < (1, 0, 2, 0, 0), 'OpenSSL too old')
882    def test_get_ciphers(self):
883        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
884        ctx.set_ciphers('AESGCM')
885        names = set(d['name'] for d in ctx.get_ciphers())
886        self.assertIn('AES256-GCM-SHA384', names)
887        self.assertIn('AES128-GCM-SHA256', names)
888
889    @skip_if_broken_ubuntu_ssl
890    def test_options(self):
891        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
892        # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
893        default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
894        # SSLContext also enables these by default
895        default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
896                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE)
897        self.assertEqual(default, ctx.options)
898        ctx.options |= ssl.OP_NO_TLSv1
899        self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
900        if can_clear_options():
901            ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
902            self.assertEqual(default, ctx.options)
903            ctx.options = 0
904            # Ubuntu has OP_NO_SSLv3 forced on by default
905            self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
906        else:
907            with self.assertRaises(ValueError):
908                ctx.options = 0
909
910    def test_verify_mode(self):
911        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
912        # Default value
913        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
914        ctx.verify_mode = ssl.CERT_OPTIONAL
915        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
916        ctx.verify_mode = ssl.CERT_REQUIRED
917        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
918        ctx.verify_mode = ssl.CERT_NONE
919        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
920        with self.assertRaises(TypeError):
921            ctx.verify_mode = None
922        with self.assertRaises(ValueError):
923            ctx.verify_mode = 42
924
925    @unittest.skipUnless(have_verify_flags(),
926                         "verify_flags need OpenSSL > 0.9.8")
927    def test_verify_flags(self):
928        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
929        # default value
930        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
931        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
932        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
933        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF)
934        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN
935        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN)
936        ctx.verify_flags = ssl.VERIFY_DEFAULT
937        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT)
938        # supports any value
939        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT
940        self.assertEqual(ctx.verify_flags,
941                         ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT)
942        with self.assertRaises(TypeError):
943            ctx.verify_flags = None
944
945    def test_load_cert_chain(self):
946        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
947        # Combined key and cert in a single file
948        ctx.load_cert_chain(CERTFILE, keyfile=None)
949        ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
950        self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
951        with self.assertRaises(OSError) as cm:
952            ctx.load_cert_chain(NONEXISTINGCERT)
953        self.assertEqual(cm.exception.errno, errno.ENOENT)
954        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
955            ctx.load_cert_chain(BADCERT)
956        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
957            ctx.load_cert_chain(EMPTYCERT)
958        # Separate key and cert
959        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
960        ctx.load_cert_chain(ONLYCERT, ONLYKEY)
961        ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
962        ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
963        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
964            ctx.load_cert_chain(ONLYCERT)
965        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
966            ctx.load_cert_chain(ONLYKEY)
967        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
968            ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
969        # Mismatching key and cert
970        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
971        with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
972            ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
973        # Password protected key and cert
974        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
975        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
976        ctx.load_cert_chain(CERTFILE_PROTECTED,
977                            password=bytearray(KEY_PASSWORD.encode()))
978        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
979        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
980        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
981                            bytearray(KEY_PASSWORD.encode()))
982        with self.assertRaisesRegex(TypeError, "should be a string"):
983            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
984        with self.assertRaises(ssl.SSLError):
985            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
986        with self.assertRaisesRegex(ValueError, "cannot be longer"):
987            # openssl has a fixed limit on the password buffer.
988            # PEM_BUFSIZE is generally set to 1kb.
989            # Return a string larger than this.
990            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
991        # Password callback
992        def getpass_unicode():
993            return KEY_PASSWORD
994        def getpass_bytes():
995            return KEY_PASSWORD.encode()
996        def getpass_bytearray():
997            return bytearray(KEY_PASSWORD.encode())
998        def getpass_badpass():
999            return "badpass"
1000        def getpass_huge():
1001            return b'a' * (1024 * 1024)
1002        def getpass_bad_type():
1003            return 9
1004        def getpass_exception():
1005            raise Exception('getpass error')
1006        class GetPassCallable:
1007            def __call__(self):
1008                return KEY_PASSWORD
1009            def getpass(self):
1010                return KEY_PASSWORD
1011        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
1012        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
1013        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
1014        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
1015        ctx.load_cert_chain(CERTFILE_PROTECTED,
1016                            password=GetPassCallable().getpass)
1017        with self.assertRaises(ssl.SSLError):
1018            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
1019        with self.assertRaisesRegex(ValueError, "cannot be longer"):
1020            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
1021        with self.assertRaisesRegex(TypeError, "must return a string"):
1022            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
1023        with self.assertRaisesRegex(Exception, "getpass error"):
1024            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
1025        # Make sure the password function isn't called if it isn't needed
1026        ctx.load_cert_chain(CERTFILE, password=getpass_exception)
1027
1028    def test_load_verify_locations(self):
1029        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1030        ctx.load_verify_locations(CERTFILE)
1031        ctx.load_verify_locations(cafile=CERTFILE, capath=None)
1032        ctx.load_verify_locations(BYTES_CERTFILE)
1033        ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
1034        self.assertRaises(TypeError, ctx.load_verify_locations)
1035        self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
1036        with self.assertRaises(OSError) as cm:
1037            ctx.load_verify_locations(NONEXISTINGCERT)
1038        self.assertEqual(cm.exception.errno, errno.ENOENT)
1039        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1040            ctx.load_verify_locations(BADCERT)
1041        ctx.load_verify_locations(CERTFILE, CAPATH)
1042        ctx.load_verify_locations(CERTFILE, capath=BYTES_CAPATH)
1043
1044        # Issue #10989: crash if the second argument type is invalid
1045        self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
1046
1047    def test_load_verify_cadata(self):
1048        # test cadata
1049        with open(CAFILE_CACERT) as f:
1050            cacert_pem = f.read()
1051        cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem)
1052        with open(CAFILE_NEURONIO) as f:
1053            neuronio_pem = f.read()
1054        neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)
1055
1056        # test PEM
1057        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1058        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
1059        ctx.load_verify_locations(cadata=cacert_pem)
1060        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
1061        ctx.load_verify_locations(cadata=neuronio_pem)
1062        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1063        # cert already in hash table
1064        ctx.load_verify_locations(cadata=neuronio_pem)
1065        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1066
1067        # combined
1068        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1069        combined = "\n".join((cacert_pem, neuronio_pem))
1070        ctx.load_verify_locations(cadata=combined)
1071        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1072
1073        # with junk around the certs
1074        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1075        combined = ["head", cacert_pem, "other", neuronio_pem, "again",
1076                    neuronio_pem, "tail"]
1077        ctx.load_verify_locations(cadata="\n".join(combined))
1078        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1079
1080        # test DER
1081        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1082        ctx.load_verify_locations(cadata=cacert_der)
1083        ctx.load_verify_locations(cadata=neuronio_der)
1084        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1085        # cert already in hash table
1086        ctx.load_verify_locations(cadata=cacert_der)
1087        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1088
1089        # combined
1090        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1091        combined = b"".join((cacert_der, neuronio_der))
1092        ctx.load_verify_locations(cadata=combined)
1093        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1094
1095        # error cases
1096        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1097        self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)
1098
1099        with self.assertRaisesRegex(ssl.SSLError, "no start line"):
1100            ctx.load_verify_locations(cadata="broken")
1101        with self.assertRaisesRegex(ssl.SSLError, "not enough data"):
1102            ctx.load_verify_locations(cadata=b"broken")
1103
1104
1105    def test_load_dh_params(self):
1106        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1107        ctx.load_dh_params(DHFILE)
1108        if os.name != 'nt':
1109            ctx.load_dh_params(BYTES_DHFILE)
1110        self.assertRaises(TypeError, ctx.load_dh_params)
1111        self.assertRaises(TypeError, ctx.load_dh_params, None)
1112        with self.assertRaises(FileNotFoundError) as cm:
1113            ctx.load_dh_params(NONEXISTINGCERT)
1114        self.assertEqual(cm.exception.errno, errno.ENOENT)
1115        with self.assertRaises(ssl.SSLError) as cm:
1116            ctx.load_dh_params(CERTFILE)
1117
1118    @skip_if_broken_ubuntu_ssl
1119    def test_session_stats(self):
1120        for proto in PROTOCOLS:
1121            ctx = ssl.SSLContext(proto)
1122            self.assertEqual(ctx.session_stats(), {
1123                'number': 0,
1124                'connect': 0,
1125                'connect_good': 0,
1126                'connect_renegotiate': 0,
1127                'accept': 0,
1128                'accept_good': 0,
1129                'accept_renegotiate': 0,
1130                'hits': 0,
1131                'misses': 0,
1132                'timeouts': 0,
1133                'cache_full': 0,
1134            })
1135
1136    def test_set_default_verify_paths(self):
1137        # There's not much we can do to test that it acts as expected,
1138        # so just check it doesn't crash or raise an exception.
1139        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1140        ctx.set_default_verify_paths()
1141
1142    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
1143    def test_set_ecdh_curve(self):
1144        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1145        ctx.set_ecdh_curve("prime256v1")
1146        ctx.set_ecdh_curve(b"prime256v1")
1147        self.assertRaises(TypeError, ctx.set_ecdh_curve)
1148        self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
1149        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
1150        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
1151
1152    @needs_sni
1153    def test_sni_callback(self):
1154        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1155
1156        # set_servername_callback expects a callable, or None
1157        self.assertRaises(TypeError, ctx.set_servername_callback)
1158        self.assertRaises(TypeError, ctx.set_servername_callback, 4)
1159        self.assertRaises(TypeError, ctx.set_servername_callback, "")
1160        self.assertRaises(TypeError, ctx.set_servername_callback, ctx)
1161
1162        def dummycallback(sock, servername, ctx):
1163            pass
1164        ctx.set_servername_callback(None)
1165        ctx.set_servername_callback(dummycallback)
1166
1167    @needs_sni
1168    def test_sni_callback_refcycle(self):
1169        # Reference cycles through the servername callback are detected
1170        # and cleared.
1171        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1172        def dummycallback(sock, servername, ctx, cycle=ctx):
1173            pass
1174        ctx.set_servername_callback(dummycallback)
1175        wr = weakref.ref(ctx)
1176        del ctx, dummycallback
1177        gc.collect()
1178        self.assertIs(wr(), None)
1179
1180    def test_cert_store_stats(self):
1181        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1182        self.assertEqual(ctx.cert_store_stats(),
1183            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1184        ctx.load_cert_chain(CERTFILE)
1185        self.assertEqual(ctx.cert_store_stats(),
1186            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1187        ctx.load_verify_locations(CERTFILE)
1188        self.assertEqual(ctx.cert_store_stats(),
1189            {'x509_ca': 0, 'crl': 0, 'x509': 1})
1190        ctx.load_verify_locations(CAFILE_CACERT)
1191        self.assertEqual(ctx.cert_store_stats(),
1192            {'x509_ca': 1, 'crl': 0, 'x509': 2})
1193
1194    def test_get_ca_certs(self):
1195        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1196        self.assertEqual(ctx.get_ca_certs(), [])
1197        # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
1198        ctx.load_verify_locations(CERTFILE)
1199        self.assertEqual(ctx.get_ca_certs(), [])
1200        # but CAFILE_CACERT is a CA cert
1201        ctx.load_verify_locations(CAFILE_CACERT)
1202        self.assertEqual(ctx.get_ca_certs(),
1203            [{'issuer': ((('organizationName', 'Root CA'),),
1204                         (('organizationalUnitName', 'http://www.cacert.org'),),
1205                         (('commonName', 'CA Cert Signing Authority'),),
1206                         (('emailAddress', 'support@cacert.org'),)),
1207              'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'),
1208              'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'),
1209              'serialNumber': '00',
1210              'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
1211              'subject': ((('organizationName', 'Root CA'),),
1212                          (('organizationalUnitName', 'http://www.cacert.org'),),
1213                          (('commonName', 'CA Cert Signing Authority'),),
1214                          (('emailAddress', 'support@cacert.org'),)),
1215              'version': 3}])
1216
1217        with open(CAFILE_CACERT) as f:
1218            pem = f.read()
1219        der = ssl.PEM_cert_to_DER_cert(pem)
1220        self.assertEqual(ctx.get_ca_certs(True), [der])
1221
1222    def test_load_default_certs(self):
1223        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1224        ctx.load_default_certs()
1225
1226        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1227        ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
1228        ctx.load_default_certs()
1229
1230        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1231        ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)
1232
1233        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1234        self.assertRaises(TypeError, ctx.load_default_certs, None)
1235        self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
1236
1237    @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
1238    @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
1239    def test_load_default_certs_env(self):
1240        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1241        with support.EnvironmentVarGuard() as env:
1242            env["SSL_CERT_DIR"] = CAPATH
1243            env["SSL_CERT_FILE"] = CERTFILE
1244            ctx.load_default_certs()
1245            self.assertEqual(ctx.cert_store_stats(), {"crl": 0, "x509": 1, "x509_ca": 0})
1246
1247    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
1248    def test_load_default_certs_env_windows(self):
1249        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1250        ctx.load_default_certs()
1251        stats = ctx.cert_store_stats()
1252
1253        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1254        with support.EnvironmentVarGuard() as env:
1255            env["SSL_CERT_DIR"] = CAPATH
1256            env["SSL_CERT_FILE"] = CERTFILE
1257            ctx.load_default_certs()
1258            stats["x509"] += 1
1259            self.assertEqual(ctx.cert_store_stats(), stats)
1260
1261    def _assert_context_options(self, ctx):
1262        self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
1263        if OP_NO_COMPRESSION != 0:
1264            self.assertEqual(ctx.options & OP_NO_COMPRESSION,
1265                             OP_NO_COMPRESSION)
1266        if OP_SINGLE_DH_USE != 0:
1267            self.assertEqual(ctx.options & OP_SINGLE_DH_USE,
1268                             OP_SINGLE_DH_USE)
1269        if OP_SINGLE_ECDH_USE != 0:
1270            self.assertEqual(ctx.options & OP_SINGLE_ECDH_USE,
1271                             OP_SINGLE_ECDH_USE)
1272        if OP_CIPHER_SERVER_PREFERENCE != 0:
1273            self.assertEqual(ctx.options & OP_CIPHER_SERVER_PREFERENCE,
1274                             OP_CIPHER_SERVER_PREFERENCE)
1275
1276    def test_create_default_context(self):
1277        ctx = ssl.create_default_context()
1278
1279        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
1280        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1281        self.assertTrue(ctx.check_hostname)
1282        self._assert_context_options(ctx)
1283
1284
1285        with open(SIGNING_CA) as f:
1286            cadata = f.read()
1287        ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH,
1288                                         cadata=cadata)
1289        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
1290        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1291        self._assert_context_options(ctx)
1292
1293        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
1294        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
1295        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1296        self._assert_context_options(ctx)
1297
1298    def test__create_stdlib_context(self):
1299        ctx = ssl._create_stdlib_context()
1300        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
1301        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1302        self.assertFalse(ctx.check_hostname)
1303        self._assert_context_options(ctx)
1304
1305        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
1306        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1307        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1308        self._assert_context_options(ctx)
1309
1310        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1,
1311                                         cert_reqs=ssl.CERT_REQUIRED,
1312                                         check_hostname=True)
1313        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1314        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1315        self.assertTrue(ctx.check_hostname)
1316        self._assert_context_options(ctx)
1317
1318        ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH)
1319        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
1320        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1321        self._assert_context_options(ctx)
1322
1323    def test_check_hostname(self):
1324        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1325        self.assertFalse(ctx.check_hostname)
1326
1327        # Requires CERT_REQUIRED or CERT_OPTIONAL
1328        with self.assertRaises(ValueError):
1329            ctx.check_hostname = True
1330        ctx.verify_mode = ssl.CERT_REQUIRED
1331        self.assertFalse(ctx.check_hostname)
1332        ctx.check_hostname = True
1333        self.assertTrue(ctx.check_hostname)
1334
1335        ctx.verify_mode = ssl.CERT_OPTIONAL
1336        ctx.check_hostname = True
1337        self.assertTrue(ctx.check_hostname)
1338
1339        # Cannot set CERT_NONE with check_hostname enabled
1340        with self.assertRaises(ValueError):
1341            ctx.verify_mode = ssl.CERT_NONE
1342        ctx.check_hostname = False
1343        self.assertFalse(ctx.check_hostname)
1344
1345    def test_context_client_server(self):
1346        # PROTOCOL_TLS_CLIENT has sane defaults
1347        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1348        self.assertTrue(ctx.check_hostname)
1349        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1350
1351        # PROTOCOL_TLS_SERVER has different but also sane defaults
1352        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1353        self.assertFalse(ctx.check_hostname)
1354        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1355
1356
1357class SSLErrorTests(unittest.TestCase):
1358
1359    def test_str(self):
1360        # The str() of a SSLError doesn't include the errno
1361        e = ssl.SSLError(1, "foo")
1362        self.assertEqual(str(e), "foo")
1363        self.assertEqual(e.errno, 1)
1364        # Same for a subclass
1365        e = ssl.SSLZeroReturnError(1, "foo")
1366        self.assertEqual(str(e), "foo")
1367        self.assertEqual(e.errno, 1)
1368
1369    def test_lib_reason(self):
1370        # Test the library and reason attributes
1371        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1372        with self.assertRaises(ssl.SSLError) as cm:
1373            ctx.load_dh_params(CERTFILE)
1374        self.assertEqual(cm.exception.library, 'PEM')
1375        self.assertEqual(cm.exception.reason, 'NO_START_LINE')
1376        s = str(cm.exception)
1377        self.assertTrue(s.startswith("[PEM: NO_START_LINE] no start line"), s)
1378
1379    def test_subclass(self):
1380        # Check that the appropriate SSLError subclass is raised
1381        # (this only tests one of them)
1382        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1383        with socket.socket() as s:
1384            s.bind(("127.0.0.1", 0))
1385            s.listen()
1386            c = socket.socket()
1387            c.connect(s.getsockname())
1388            c.setblocking(False)
1389            with ctx.wrap_socket(c, False, do_handshake_on_connect=False) as c:
1390                with self.assertRaises(ssl.SSLWantReadError) as cm:
1391                    c.do_handshake()
1392                s = str(cm.exception)
1393                self.assertTrue(s.startswith("The operation did not complete (read)"), s)
1394                # For compatibility
1395                self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
1396
1397
class MemoryBIOTests(unittest.TestCase):
    """Exercise ssl.MemoryBIO: reads and writes, eof, pending, buffer types."""

    def test_read_write(self):
        buf = ssl.MemoryBIO()
        buf.write(b'foo')
        self.assertEqual(buf.read(), b'foo')
        self.assertEqual(buf.read(), b'')
        # Consecutive writes are returned by a single read.
        for chunk in (b'foo', b'bar'):
            buf.write(chunk)
        self.assertEqual(buf.read(), b'foobar')
        self.assertEqual(buf.read(), b'')
        # Sized reads return at most what is left.
        buf.write(b'baz')
        for size, expected in ((2, b'ba'), (1, b'z'), (1, b'')):
            self.assertEqual(buf.read(size), expected)

    def test_eof(self):
        buf = ssl.MemoryBIO()
        self.assertFalse(buf.eof)
        self.assertEqual(buf.read(), b'')
        self.assertFalse(buf.eof)
        buf.write(b'foo')
        self.assertFalse(buf.eof)
        buf.write_eof()
        # eof is only reported once the buffered data has been read back.
        self.assertFalse(buf.eof)
        self.assertEqual(buf.read(2), b'fo')
        self.assertFalse(buf.eof)
        self.assertEqual(buf.read(1), b'o')
        self.assertTrue(buf.eof)
        self.assertEqual(buf.read(), b'')
        self.assertTrue(buf.eof)

    def test_pending(self):
        buf = ssl.MemoryBIO()
        self.assertEqual(buf.pending, 0)
        buf.write(b'foo')
        self.assertEqual(buf.pending, 3)
        # Drain one byte at a time...
        for remaining in (2, 1, 0):
            buf.read(1)
            self.assertEqual(buf.pending, remaining)
        # ...then refill one byte at a time.
        for total in (1, 2, 3):
            buf.write(b'x')
            self.assertEqual(buf.pending, total)
        buf.read()
        self.assertEqual(buf.pending, 0)

    def test_buffer_types(self):
        buf = ssl.MemoryBIO()
        samples = (
            (b'foo', b'foo'),
            (bytearray(b'bar'), b'bar'),
            (memoryview(b'baz'), b'baz'),
        )
        for payload, expected in samples:
            buf.write(payload)
            self.assertEqual(buf.read(), expected)

    def test_error_types(self):
        buf = ssl.MemoryBIO()
        for bad_value in ('foo', None, True, 1):
            with self.assertRaises(TypeError):
                buf.write(bad_value)
1459
1460
1461@unittest.skipUnless(_have_threads, "Needs threading module")
1462class SimpleBackgroundTests(unittest.TestCase):
1463
1464    """Tests that connect to a simple server running in the background"""
1465
1466    def setUp(self):
1467        server = ThreadedEchoServer(SIGNED_CERTFILE)
1468        self.server_addr = (HOST, server.port)
1469        server.__enter__()
1470        self.addCleanup(server.__exit__, None, None, None)
1471
1472    def test_connect(self):
1473        with test_wrap_socket(socket.socket(socket.AF_INET),
1474                            cert_reqs=ssl.CERT_NONE) as s:
1475            s.connect(self.server_addr)
1476            self.assertEqual({}, s.getpeercert())
1477            self.assertFalse(s.server_side)
1478
1479        # this should succeed because we specify the root cert
1480        with test_wrap_socket(socket.socket(socket.AF_INET),
1481                            cert_reqs=ssl.CERT_REQUIRED,
1482                            ca_certs=SIGNING_CA) as s:
1483            s.connect(self.server_addr)
1484            self.assertTrue(s.getpeercert())
1485            self.assertFalse(s.server_side)
1486
1487    def test_connect_fail(self):
1488        # This should fail because we have no verification certs. Connection
1489        # failure crashes ThreadedEchoServer, so run this in an independent
1490        # test method.
1491        s = test_wrap_socket(socket.socket(socket.AF_INET),
1492                            cert_reqs=ssl.CERT_REQUIRED)
1493        self.addCleanup(s.close)
1494        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
1495                               s.connect, self.server_addr)
1496
1497    def test_connect_ex(self):
1498        # Issue #11326: check connect_ex() implementation
1499        s = test_wrap_socket(socket.socket(socket.AF_INET),
1500                            cert_reqs=ssl.CERT_REQUIRED,
1501                            ca_certs=SIGNING_CA)
1502        self.addCleanup(s.close)
1503        self.assertEqual(0, s.connect_ex(self.server_addr))
1504        self.assertTrue(s.getpeercert())
1505
1506    def test_non_blocking_connect_ex(self):
1507        # Issue #11326: non-blocking connect_ex() should allow handshake
1508        # to proceed after the socket gets ready.
1509        s = test_wrap_socket(socket.socket(socket.AF_INET),
1510                            cert_reqs=ssl.CERT_REQUIRED,
1511                            ca_certs=SIGNING_CA,
1512                            do_handshake_on_connect=False)
1513        self.addCleanup(s.close)
1514        s.setblocking(False)
1515        rc = s.connect_ex(self.server_addr)
1516        # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
1517        self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
1518        # Wait for connect to finish
1519        select.select([], [s], [], 5.0)
1520        # Non-blocking handshake
1521        while True:
1522            try:
1523                s.do_handshake()
1524                break
1525            except ssl.SSLWantReadError:
1526                select.select([s], [], [], 5.0)
1527            except ssl.SSLWantWriteError:
1528                select.select([], [s], [], 5.0)
1529        # SSL established
1530        self.assertTrue(s.getpeercert())
1531
1532    def test_connect_with_context(self):
1533        # Same as test_connect, but with a separately created context
1534        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1535        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
1536            s.connect(self.server_addr)
1537            self.assertEqual({}, s.getpeercert())
1538        # Same with a server hostname
1539        with ctx.wrap_socket(socket.socket(socket.AF_INET),
1540                            server_hostname="dummy") as s:
1541            s.connect(self.server_addr)
1542        ctx.verify_mode = ssl.CERT_REQUIRED
1543        # This should succeed because we specify the root cert
1544        ctx.load_verify_locations(SIGNING_CA)
1545        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
1546            s.connect(self.server_addr)
1547            cert = s.getpeercert()
1548            self.assertTrue(cert)
1549
1550    def test_connect_with_context_fail(self):
1551        # This should fail because we have no verification certs. Connection
1552        # failure crashes ThreadedEchoServer, so run this in an independent
1553        # test method.
1554        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1555        ctx.verify_mode = ssl.CERT_REQUIRED
1556        s = ctx.wrap_socket(socket.socket(socket.AF_INET))
1557        self.addCleanup(s.close)
1558        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
1559                                s.connect, self.server_addr)
1560
1561    def test_connect_capath(self):
1562        # Verify server certificates using the `capath` argument
1563        # NOTE: the subject hashing algorithm has been changed between
1564        # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must
1565        # contain both versions of each certificate (same content, different
1566        # filename) for this test to be portable across OpenSSL releases.
1567        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1568        ctx.verify_mode = ssl.CERT_REQUIRED
1569        ctx.load_verify_locations(capath=CAPATH)
1570        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
1571            s.connect(self.server_addr)
1572            cert = s.getpeercert()
1573            self.assertTrue(cert)
1574        # Same with a bytes `capath` argument
1575        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1576        ctx.verify_mode = ssl.CERT_REQUIRED
1577        ctx.load_verify_locations(capath=BYTES_CAPATH)
1578        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
1579            s.connect(self.server_addr)
1580            cert = s.getpeercert()
1581            self.assertTrue(cert)
1582
1583    def test_connect_cadata(self):
1584        with open(SIGNING_CA) as f:
1585            pem = f.read()
1586        der = ssl.PEM_cert_to_DER_cert(pem)
1587        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1588        ctx.verify_mode = ssl.CERT_REQUIRED
1589        ctx.load_verify_locations(cadata=pem)
1590        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
1591            s.connect(self.server_addr)
1592            cert = s.getpeercert()
1593            self.assertTrue(cert)
1594
1595        # same with DER
1596        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1597        ctx.verify_mode = ssl.CERT_REQUIRED
1598        ctx.load_verify_locations(cadata=der)
1599        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
1600            s.connect(self.server_addr)
1601            cert = s.getpeercert()
1602            self.assertTrue(cert)
1603
1604    @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
1605    def test_makefile_close(self):
1606        # Issue #5238: creating a file-like object with makefile() shouldn't
1607        # delay closing the underlying "real socket" (here tested with its
1608        # file descriptor, hence skipping the test under Windows).
1609        ss = test_wrap_socket(socket.socket(socket.AF_INET))
1610        ss.connect(self.server_addr)
1611        fd = ss.fileno()
1612        f = ss.makefile()
1613        f.close()
1614        # The fd is still open
1615        os.read(fd, 0)
1616        # Closing the SSL socket should close the fd too
1617        ss.close()
1618        gc.collect()
1619        with self.assertRaises(OSError) as e:
1620            os.read(fd, 0)
1621        self.assertEqual(e.exception.errno, errno.EBADF)
1622
1623    def test_non_blocking_handshake(self):
1624        s = socket.socket(socket.AF_INET)
1625        s.connect(self.server_addr)
1626        s.setblocking(False)
1627        s = test_wrap_socket(s,
1628                            cert_reqs=ssl.CERT_NONE,
1629                            do_handshake_on_connect=False)
1630        self.addCleanup(s.close)
1631        count = 0
1632        while True:
1633            try:
1634                count += 1
1635                s.do_handshake()
1636                break
1637            except ssl.SSLWantReadError:
1638                select.select([s], [], [])
1639            except ssl.SSLWantWriteError:
1640                select.select([], [s], [])
1641        if support.verbose:
1642            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
1643
1644    def test_get_server_certificate(self):
1645        _test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)
1646
1647    def test_get_server_certificate_fail(self):
1648        # Connection failure crashes ThreadedEchoServer, so run this in an
1649        # independent test method
1650        _test_get_server_certificate_fail(self, *self.server_addr)
1651
1652    def test_ciphers(self):
1653        with test_wrap_socket(socket.socket(socket.AF_INET),
1654                             cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
1655            s.connect(self.server_addr)
1656        with test_wrap_socket(socket.socket(socket.AF_INET),
1657                             cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
1658            s.connect(self.server_addr)
1659        # Error checking can happen at instantiation or when connecting
1660        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
1661            with socket.socket(socket.AF_INET) as sock:
1662                s = test_wrap_socket(sock,
1663                                    cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
1664                s.connect(self.server_addr)
1665
1666    def test_get_ca_certs_capath(self):
1667        # capath certs are loaded on request
1668        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1669        ctx.verify_mode = ssl.CERT_REQUIRED
1670        ctx.load_verify_locations(capath=CAPATH)
1671        self.assertEqual(ctx.get_ca_certs(), [])
1672        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
1673            s.connect(self.server_addr)
1674            cert = s.getpeercert()
1675            self.assertTrue(cert)
1676        self.assertEqual(len(ctx.get_ca_certs()), 1)
1677
1678    @needs_sni
1679    def test_context_setget(self):
1680        # Check that the context of a connected socket can be replaced.
1681        ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
1682        ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1683        s = socket.socket(socket.AF_INET)
1684        with ctx1.wrap_socket(s) as ss:
1685            ss.connect(self.server_addr)
1686            self.assertIs(ss.context, ctx1)
1687            self.assertIs(ss._sslobj.context, ctx1)
1688            ss.context = ctx2
1689            self.assertIs(ss.context, ctx2)
1690            self.assertIs(ss._sslobj.context, ctx2)
1691
1692    def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
1693        # A simple IO loop. Call func(*args) depending on the error we get
1694        # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
1695        timeout = kwargs.get('timeout', 10)
1696        count = 0
1697        while True:
1698            errno = None
1699            count += 1
1700            try:
1701                ret = func(*args)
1702            except ssl.SSLError as e:
1703                if e.errno not in (ssl.SSL_ERROR_WANT_READ,
1704                                   ssl.SSL_ERROR_WANT_WRITE):
1705                    raise
1706                errno = e.errno
1707            # Get any data from the outgoing BIO irrespective of any error, and
1708            # send it to the socket.
1709            buf = outgoing.read()
1710            sock.sendall(buf)
1711            # If there's no error, we're done. For WANT_READ, we need to get
1712            # data from the socket and put it in the incoming BIO.
1713            if errno is None:
1714                break
1715            elif errno == ssl.SSL_ERROR_WANT_READ:
1716                buf = sock.recv(32768)
1717                if buf:
1718                    incoming.write(buf)
1719                else:
1720                    incoming.write_eof()
1721        if support.verbose:
1722            sys.stdout.write("Needed %d calls to complete %s().\n"
1723                             % (count, func.__name__))
1724        return ret
1725
1726    def test_bio_handshake(self):
1727        sock = socket.socket(socket.AF_INET)
1728        self.addCleanup(sock.close)
1729        sock.connect(self.server_addr)
1730        incoming = ssl.MemoryBIO()
1731        outgoing = ssl.MemoryBIO()
1732        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1733        ctx.verify_mode = ssl.CERT_REQUIRED
1734        ctx.load_verify_locations(SIGNING_CA)
1735        ctx.check_hostname = True
1736        sslobj = ctx.wrap_bio(incoming, outgoing, False, 'localhost')
1737        self.assertIs(sslobj._sslobj.owner, sslobj)
1738        self.assertIsNone(sslobj.cipher())
1739        self.assertIsNotNone(sslobj.shared_ciphers())
1740        self.assertRaises(ValueError, sslobj.getpeercert)
1741        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
1742            self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
1743        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
1744        self.assertTrue(sslobj.cipher())
1745        self.assertIsNotNone(sslobj.shared_ciphers())
1746        self.assertTrue(sslobj.getpeercert())
1747        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
1748            self.assertTrue(sslobj.get_channel_binding('tls-unique'))
1749        try:
1750            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
1751        except ssl.SSLSyscallError:
1752            # If the server shuts down the TCP connection without sending a
1753            # secure shutdown message, this is reported as SSL_ERROR_SYSCALL
1754            pass
1755        self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
1756
1757    def test_bio_read_write_data(self):
1758        sock = socket.socket(socket.AF_INET)
1759        self.addCleanup(sock.close)
1760        sock.connect(self.server_addr)
1761        incoming = ssl.MemoryBIO()
1762        outgoing = ssl.MemoryBIO()
1763        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1764        ctx.verify_mode = ssl.CERT_NONE
1765        sslobj = ctx.wrap_bio(incoming, outgoing, False)
1766        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
1767        req = b'FOO\n'
1768        self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
1769        buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
1770        self.assertEqual(buf, b'foo\n')
1771        self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
1772
1773
class NetworkedTests(unittest.TestCase):
    """Tests that connect to remote hosts, guarded by transient_internet()."""

    def test_timeout_connect_ex(self):
        # Issue #12065: on a timeout, connect_ex() should return the original
        # errno (mimicking the behaviour of non-SSL sockets).
        with support.transient_internet(REMOTE_HOST):
            client = test_wrap_socket(socket.socket(socket.AF_INET),
                                      cert_reqs=ssl.CERT_REQUIRED,
                                      do_handshake_on_connect=False)
            self.addCleanup(client.close)
            client.settimeout(0.0000001)
            status = client.connect_ex((REMOTE_HOST, 443))
            if status == 0:
                self.skipTest("REMOTE_HOST responded too quickly")
            self.assertIn(status, (errno.EAGAIN, errno.EWOULDBLOCK))

    @unittest.skipUnless(support.IPV6_ENABLED, 'Needs IPv6')
    def test_get_server_certificate_ipv6(self):
        host = 'ipv6.google.com'
        with support.transient_internet(host):
            _test_get_server_certificate(self, host, 443)
            _test_get_server_certificate_fail(self, host, 443)

    def test_algorithms(self):
        # Issue #8484: all algorithms should be available when verifying a
        # certificate.
        # SHA256 was added in OpenSSL 0.9.8
        if ssl.OPENSSL_VERSION_INFO < (0, 9, 8, 0, 15):
            self.skipTest("SHA256 not available on %r" % ssl.OPENSSL_VERSION)
        # sha256.tbs-internet.com needs SNI to use the correct certificate
        if not ssl.HAS_SNI:
            self.skipTest("SNI needed for this test")
        # https://sha2.hboeck.de/ was used until 2011-01-08 (no route to host)
        hostname = "sha256.tbs-internet.com"
        remote = (hostname, 443)
        sha256_cert = os.path.join(os.path.dirname(__file__), "sha256.pem")
        with support.transient_internet(hostname):
            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
            context.verify_mode = ssl.CERT_REQUIRED
            context.load_verify_locations(sha256_cert)
            client = context.wrap_socket(socket.socket(socket.AF_INET),
                                         server_hostname=hostname)
            # Leaving the block closes the socket, also on failure.
            with client:
                client.connect(remote)
                if support.verbose:
                    sys.stdout.write("\nCipher with %r is %r\n" %
                                     (remote, client.cipher()))
                    sys.stdout.write("Certificate is:\n%s\n" %
                                     pprint.pformat(client.getpeercert()))
1823
1824
def _test_get_server_certificate(test, host, port, cert=None):
    """Fetch the certificate of host:port twice and fail *test* if missing.

    The first fetch passes no CA certificates; the second passes *cert*
    as ``ca_certs``.
    """
    address = (host, port)
    if not ssl.get_server_certificate(address):
        test.fail("No server certificate on %s:%s!" % address)

    pem = ssl.get_server_certificate(address, ca_certs=cert)
    if not pem:
        test.fail("No server certificate on %s:%s!" % address)
    if support.verbose:
        sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port, pem))
1835
def _test_get_server_certificate_fail(test, host, port):
    """Expect fetching the certificate of host:port to raise SSLError.

    CERTFILE is passed as ``ca_certs``; *test* is failed if a certificate
    is returned anyway.
    """
    try:
        pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
    except ssl.SSLError as exc:
        # should fail
        if support.verbose:
            sys.stdout.write("%s\n" % exc)
        return
    test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
1845
1846
1847if _have_threads:
1848    from test.ssl_servers import make_https_server
1849
1850    class ThreadedEchoServer(threading.Thread):
1851
1852        class ConnectionHandler(threading.Thread):
1853
1854            """A mildly complicated class, because we want it to work both
1855            with and without the SSL wrapper around the socket connection, so
1856            that we can test the STARTTLS functionality."""
1857
1858            def __init__(self, server, connsock, addr):
1859                self.server = server
1860                self.running = False
1861                self.sock = connsock
1862                self.addr = addr
1863                self.sock.setblocking(1)
1864                self.sslconn = None
1865                threading.Thread.__init__(self)
1866                self.daemon = True
1867
            def wrap_conn(self):
                """Wrap the connection socket with the server's SSL context.

                On success the NPN/ALPN protocols and shared ciphers of the
                new connection are appended to the server's lists and True
                is returned.  On SSLError or ConnectionResetError the error
                is appended to server.conn_errors, the handler and the
                server are stopped, the connection is closed and False is
                returned.
                """
                try:
                    self.sslconn = self.server.context.wrap_socket(
                        self.sock, server_side=True)
                    self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
                    self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
                except (ssl.SSLError, ConnectionResetError) as e:
                    # We treat ConnectionResetError as though it were an
                    # SSLError - OpenSSL on Ubuntu abruptly closes the
                    # connection when asked to use an unsupported protocol.
                    #
                    # XXX Various errors can have happened here, for example
                    # a mismatching protocol version, an invalid certificate,
                    # or a low-level bug. This should be made more discriminating.
                    self.server.conn_errors.append(e)
                    if self.server.chatty:
                        handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                    self.running = False
                    self.server.stop()
                    self.close()
                    return False
                else:
                    self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
                    # The peer certificate is only inspected when the context
                    # requires one; it is used for verbose output only.
                    if self.server.context.verify_mode == ssl.CERT_REQUIRED:
                        cert = self.sslconn.getpeercert()
                        if support.verbose and self.server.chatty:
                            sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
                        cert_binary = self.sslconn.getpeercert(True)
                        if support.verbose and self.server.chatty:
                            sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
                    cipher = self.sslconn.cipher()
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
                        sys.stdout.write(" server: selected protocol is now "
                                + str(self.sslconn.selected_npn_protocol()) + "\n")
                    return True
1904
1905            def read(self):
1906                if self.sslconn:
1907                    return self.sslconn.read()
1908                else:
1909                    return self.sock.recv(1024)
1910
1911            def write(self, bytes):
1912                if self.sslconn:
1913                    return self.sslconn.write(bytes)
1914                else:
1915                    return self.sock.send(bytes)
1916
1917            def close(self):
1918                if self.sslconn:
1919                    self.sslconn.close()
1920                else:
1921                    self.sock.close()
1922
1923            def run(self):
1924                self.running = True
1925                if not self.server.starttls_server:
1926                    if not self.wrap_conn():
1927                        return
1928                while self.running:
1929                    try:
1930                        msg = self.read()
1931                        stripped = msg.strip()
1932                        if not stripped:
1933                            # eof, so quit this handler
1934                            self.running = False
1935                            try:
1936                                self.sock = self.sslconn.unwrap()
1937                            except OSError:
1938                                # Many tests shut the TCP connection down
1939                                # without an SSL shutdown. This causes
1940                                # unwrap() to raise OSError with errno=0!
1941                                pass
1942                            else:
1943                                self.sslconn = None
1944                            self.close()
1945                        elif stripped == b'over':
1946                            if support.verbose and self.server.connectionchatty:
1947                                sys.stdout.write(" server: client closed connection\n")
1948                            self.close()
1949                            return
1950                        elif (self.server.starttls_server and
1951                              stripped == b'STARTTLS'):
1952                            if support.verbose and self.server.connectionchatty:
1953                                sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
1954                            self.write(b"OK\n")
1955                            if not self.wrap_conn():
1956                                return
1957                        elif (self.server.starttls_server and self.sslconn
1958                              and stripped == b'ENDTLS'):
1959                            if support.verbose and self.server.connectionchatty:
1960                                sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
1961                            self.write(b"OK\n")
1962                            self.sock = self.sslconn.unwrap()
1963                            self.sslconn = None
1964                            if support.verbose and self.server.connectionchatty:
1965                                sys.stdout.write(" server: connection is now unencrypted...\n")
1966                        elif stripped == b'CB tls-unique':
1967                            if support.verbose and self.server.connectionchatty:
1968                                sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
1969                            data = self.sslconn.get_channel_binding("tls-unique")
1970                            self.write(repr(data).encode("us-ascii") + b"\n")
1971                        else:
1972                            if (support.verbose and
1973                                self.server.connectionchatty):
1974                                ctype = (self.sslconn and "encrypted") or "unencrypted"
1975                                sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
1976                                                 % (msg, ctype, msg.lower(), ctype))
1977                            self.write(msg.lower())
1978                    except OSError:
1979                        if self.server.chatty:
1980                            handle_error("Test server failure:\n")
1981                        self.close()
1982                        self.running = False
1983                        # normally, we'd just stop here, but for the test
1984                        # harness, we want to stop the server
1985                        self.server.stop()
1986
1987        def __init__(self, certificate=None, ssl_version=None,
1988                     certreqs=None, cacerts=None,
1989                     chatty=True, connectionchatty=False, starttls_server=False,
1990                     npn_protocols=None, alpn_protocols=None,
1991                     ciphers=None, context=None):
1992            if context:
1993                self.context = context
1994            else:
1995                self.context = ssl.SSLContext(ssl_version
1996                                              if ssl_version is not None
1997                                              else ssl.PROTOCOL_TLSv1)
1998                self.context.verify_mode = (certreqs if certreqs is not None
1999                                            else ssl.CERT_NONE)
2000                if cacerts:
2001                    self.context.load_verify_locations(cacerts)
2002                if certificate:
2003                    self.context.load_cert_chain(certificate)
2004                if npn_protocols:
2005                    self.context.set_npn_protocols(npn_protocols)
2006                if alpn_protocols:
2007                    self.context.set_alpn_protocols(alpn_protocols)
2008                if ciphers:
2009                    self.context.set_ciphers(ciphers)
2010            self.chatty = chatty
2011            self.connectionchatty = connectionchatty
2012            self.starttls_server = starttls_server
2013            self.sock = socket.socket()
2014            self.port = support.bind_port(self.sock)
2015            self.flag = None
2016            self.active = False
2017            self.selected_npn_protocols = []
2018            self.selected_alpn_protocols = []
2019            self.shared_ciphers = []
2020            self.conn_errors = []
2021            threading.Thread.__init__(self)
2022            self.daemon = True
2023
2024        def __enter__(self):
2025            self.start(threading.Event())
2026            self.flag.wait()
2027            return self
2028
        def __exit__(self, *args):
            """Ask the accept loop to stop and wait for the thread to end.

            The exception details in *args* are ignored.
            """
            self.stop()
            self.join()
2032
2033        def start(self, flag=None):
2034            self.flag = flag
2035            threading.Thread.start(self)
2036
2037        def run(self):
2038            self.sock.settimeout(0.05)
2039            self.sock.listen()
2040            self.active = True
2041            if self.flag:
2042                # signal an event
2043                self.flag.set()
2044            while self.active:
2045                try:
2046                    newconn, connaddr = self.sock.accept()
2047                    if support.verbose and self.chatty:
2048                        sys.stdout.write(' server:  new connection from '
2049                                         + repr(connaddr) + '\n')
2050                    handler = self.ConnectionHandler(self, newconn, connaddr)
2051                    handler.start()
2052                    handler.join()
2053                except socket.timeout:
2054                    pass
2055                except KeyboardInterrupt:
2056                    self.stop()
2057            self.sock.close()
2058
        def stop(self):
            """Clear the ``active`` flag so that run() leaves its accept loop."""
            self.active = False
2061
2062    class AsyncoreEchoServer(threading.Thread):
2063
2064        # this one's based on asyncore.dispatcher
2065
2066        class EchoServer (asyncore.dispatcher):
2067
2068            class ConnectionHandler (asyncore.dispatcher_with_send):
2069
                def __init__(self, conn, certfile):
                    # Wrap the accepted connection as a server-side SSL
                    # socket without handshaking, hand it to the asyncore
                    # dispatcher, then make a first handshake attempt.
                    self.socket = test_wrap_socket(conn, server_side=True,
                                                  certfile=certfile,
                                                  do_handshake_on_connect=False)
                    asyncore.dispatcher_with_send.__init__(self, self.socket)
                    # True until _do_ssl_handshake() completes without error.
                    self._ssl_accepting = True
                    self._do_ssl_handshake()
2077
2078                def readable(self):
2079                    if isinstance(self.socket, ssl.SSLSocket):
2080                        while self.socket.pending() > 0:
2081                            self.handle_read_event()
2082                    return True
2083
2084                def _do_ssl_handshake(self):
2085                    try:
2086                        self.socket.do_handshake()
2087                    except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
2088                        return
2089                    except ssl.SSLEOFError:
2090                        return self.handle_close()
2091                    except ssl.SSLError:
2092                        raise
2093                    except OSError as err:
2094                        if err.args[0] == errno.ECONNABORTED:
2095                            return self.handle_close()
2096                    else:
2097                        self._ssl_accepting = False
2098
2099                def handle_read(self):
2100                    if self._ssl_accepting:
2101                        self._do_ssl_handshake()
2102                    else:
2103                        data = self.recv(1024)
2104                        if support.verbose:
2105                            sys.stdout.write(" server:  read %s from client\n" % repr(data))
2106                        if not data:
2107                            self.close()
2108                        else:
2109                            self.send(data.lower())
2110
2111                def handle_close(self):
2112                    self.close()
2113                    if support.verbose:
2114                        sys.stdout.write(" server:  closed connection %s\n" % self.socket)
2115
2116                def handle_error(self):
2117                    raise
2118
2119            def __init__(self, certfile):
2120                self.certfile = certfile
2121                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
2122                self.port = support.bind_port(sock, '')
2123                asyncore.dispatcher.__init__(self, sock)
2124                self.listen(5)
2125
2126            def handle_accepted(self, sock_obj, addr):
2127                if support.verbose:
2128                    sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
2129                self.ConnectionHandler(sock_obj, self.certfile)
2130
2131            def handle_error(self):
2132                raise
2133
2134        def __init__(self, certfile):
2135            self.flag = None
2136            self.active = False
2137            self.server = self.EchoServer(certfile)
2138            self.port = self.server.port
2139            threading.Thread.__init__(self)
2140            self.daemon = True
2141
2142        def __str__(self):
2143            return "<%s %s>" % (self.__class__.__name__, self.server)
2144
2145        def __enter__(self):
2146            self.start(threading.Event())
2147            self.flag.wait()
2148            return self
2149
2150        def __exit__(self, *args):
2151            if support.verbose:
2152                sys.stdout.write(" cleanup: stopping server.\n")
2153            self.stop()
2154            if support.verbose:
2155                sys.stdout.write(" cleanup: joining server thread.\n")
2156            self.join()
2157            if support.verbose:
2158                sys.stdout.write(" cleanup: successfully joined.\n")
2159
2160        def start (self, flag=None):
2161            self.flag = flag
2162            threading.Thread.start(self)
2163
2164        def run(self):
2165            self.active = True
2166            if self.flag:
2167                self.flag.set()
2168            while self.active:
2169                try:
2170                    asyncore.loop(1)
2171                except:
2172                    pass
2173
2174        def stop(self):
2175            self.active = False
2176            self.server.close()
2177
2178    def server_params_test(client_context, server_context, indata=b"FOO\n",
2179                           chatty=True, connectionchatty=False, sni_name=None,
2180                           session=None):
2181        """
2182        Launch a server, connect a client to it and try various reads
2183        and writes.
2184        """
2185        stats = {}
2186        server = ThreadedEchoServer(context=server_context,
2187                                    chatty=chatty,
2188                                    connectionchatty=False)
2189        with server:
2190            with client_context.wrap_socket(socket.socket(),
2191                    server_hostname=sni_name, session=session) as s:
2192                s.connect((HOST, server.port))
2193                for arg in [indata, bytearray(indata), memoryview(indata)]:
2194                    if connectionchatty:
2195                        if support.verbose:
2196                            sys.stdout.write(
2197                                " client:  sending %r...\n" % indata)
2198                    s.write(arg)
2199                    outdata = s.read()
2200                    if connectionchatty:
2201                        if support.verbose:
2202                            sys.stdout.write(" client:  read %r\n" % outdata)
2203                    if outdata != indata.lower():
2204                        raise AssertionError(
2205                            "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
2206                            % (outdata[:20], len(outdata),
2207                               indata[:20].lower(), len(indata)))
2208                s.write(b"over\n")
2209                if connectionchatty:
2210                    if support.verbose:
2211                        sys.stdout.write(" client:  closing connection.\n")
2212                stats.update({
2213                    'compression': s.compression(),
2214                    'cipher': s.cipher(),
2215                    'peercert': s.getpeercert(),
2216                    'client_alpn_protocol': s.selected_alpn_protocol(),
2217                    'client_npn_protocol': s.selected_npn_protocol(),
2218                    'version': s.version(),
2219                    'session_reused': s.session_reused,
2220                    'session': s.session,
2221                })
2222                s.close()
2223            stats['server_alpn_protocols'] = server.selected_alpn_protocols
2224            stats['server_npn_protocols'] = server.selected_npn_protocols
2225            stats['server_shared_ciphers'] = server.shared_ciphers
2226        return stats
2227
2228    def try_protocol_combo(server_protocol, client_protocol, expect_success,
2229                           certsreqs=None, server_options=0, client_options=0):
2230        """
2231        Try to SSL-connect using *client_protocol* to *server_protocol*.
2232        If *expect_success* is true, assert that the connection succeeds,
2233        if it's false, assert that the connection fails.
2234        Also, if *expect_success* is a string, assert that it is the protocol
2235        version actually used by the connection.
2236        """
2237        if certsreqs is None:
2238            certsreqs = ssl.CERT_NONE
2239        certtype = {
2240            ssl.CERT_NONE: "CERT_NONE",
2241            ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
2242            ssl.CERT_REQUIRED: "CERT_REQUIRED",
2243        }[certsreqs]
2244        if support.verbose:
2245            formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
2246            sys.stdout.write(formatstr %
2247                             (ssl.get_protocol_name(client_protocol),
2248                              ssl.get_protocol_name(server_protocol),
2249                              certtype))
2250        client_context = ssl.SSLContext(client_protocol)
2251        client_context.options |= client_options
2252        server_context = ssl.SSLContext(server_protocol)
2253        server_context.options |= server_options
2254
2255        # NOTE: we must enable "ALL" ciphers on the client, otherwise an
2256        # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
2257        # starting from OpenSSL 1.0.0 (see issue #8322).
2258        if client_context.protocol == ssl.PROTOCOL_SSLv23:
2259            client_context.set_ciphers("ALL")
2260
2261        for ctx in (client_context, server_context):
2262            ctx.verify_mode = certsreqs
2263            ctx.load_cert_chain(CERTFILE)
2264            ctx.load_verify_locations(CERTFILE)
2265        try:
2266            stats = server_params_test(client_context, server_context,
2267                                       chatty=False, connectionchatty=False)
2268        # Protocol mismatch can result in either an SSLError, or a
2269        # "Connection reset by peer" error.
2270        except ssl.SSLError:
2271            if expect_success:
2272                raise
2273        except OSError as e:
2274            if expect_success or e.errno != errno.ECONNRESET:
2275                raise
2276        else:
2277            if not expect_success:
2278                raise AssertionError(
2279                    "Client protocol %s succeeded with server protocol %s!"
2280                    % (ssl.get_protocol_name(client_protocol),
2281                       ssl.get_protocol_name(server_protocol)))
2282            elif (expect_success is not True
2283                  and expect_success != stats['version']):
2284                raise AssertionError("version mismatch: expected %r, got %r"
2285                                     % (expect_success, stats['version']))
2286
2287
2288    class ThreadedTests(unittest.TestCase):
2289
2290        @skip_if_broken_ubuntu_ssl
2291        def test_echo(self):
2292            """Basic test of an SSL client connecting to a server"""
2293            if support.verbose:
2294                sys.stdout.write("\n")
2295            for protocol in PROTOCOLS:
2296                if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
2297                    continue
2298                with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
2299                    context = ssl.SSLContext(protocol)
2300                    context.load_cert_chain(CERTFILE)
2301                    server_params_test(context, context,
2302                                       chatty=True, connectionchatty=True)
2303
2304            client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2305            client_context.load_verify_locations(SIGNING_CA)
2306            server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
2307            # server_context.load_verify_locations(SIGNING_CA)
2308            server_context.load_cert_chain(SIGNED_CERTFILE2)
2309
2310            with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
2311                server_params_test(client_context=client_context,
2312                                   server_context=server_context,
2313                                   chatty=True, connectionchatty=True,
2314                                   sni_name='fakehostname')
2315
2316            client_context.check_hostname = False
2317            with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
2318                with self.assertRaises(ssl.SSLError) as e:
2319                    server_params_test(client_context=server_context,
2320                                       server_context=client_context,
2321                                       chatty=True, connectionchatty=True,
2322                                       sni_name='fakehostname')
2323                self.assertIn('called a function you should not call',
2324                              str(e.exception))
2325
2326            with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
2327                with self.assertRaises(ssl.SSLError) as e:
2328                    server_params_test(client_context=server_context,
2329                                       server_context=server_context,
2330                                       chatty=True, connectionchatty=True)
2331                self.assertIn('called a function you should not call',
2332                              str(e.exception))
2333
2334            with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
2335                with self.assertRaises(ssl.SSLError) as e:
2336                    server_params_test(client_context=server_context,
2337                                       server_context=client_context,
2338                                       chatty=True, connectionchatty=True)
2339                self.assertIn('called a function you should not call',
2340                              str(e.exception))
2341
2342
2343        def test_getpeercert(self):
2344            if support.verbose:
2345                sys.stdout.write("\n")
2346            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
2347            context.verify_mode = ssl.CERT_REQUIRED
2348            context.load_verify_locations(CERTFILE)
2349            context.load_cert_chain(CERTFILE)
2350            server = ThreadedEchoServer(context=context, chatty=False)
2351            with server:
2352                s = context.wrap_socket(socket.socket(),
2353                                        do_handshake_on_connect=False)
2354                s.connect((HOST, server.port))
2355                # getpeercert() raise ValueError while the handshake isn't
2356                # done.
2357                with self.assertRaises(ValueError):
2358                    s.getpeercert()
2359                s.do_handshake()
2360                cert = s.getpeercert()
2361                self.assertTrue(cert, "Can't get peer certificate.")
2362                cipher = s.cipher()
2363                if support.verbose:
2364                    sys.stdout.write(pprint.pformat(cert) + '\n')
2365                    sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
2366                if 'subject' not in cert:
2367                    self.fail("No subject field in certificate: %s." %
2368                              pprint.pformat(cert))
2369                if ((('organizationName', 'Python Software Foundation'),)
2370                    not in cert['subject']):
2371                    self.fail(
2372                        "Missing or invalid 'organizationName' field in certificate subject; "
2373                        "should be 'Python Software Foundation'.")
2374                self.assertIn('notBefore', cert)
2375                self.assertIn('notAfter', cert)
2376                before = ssl.cert_time_to_seconds(cert['notBefore'])
2377                after = ssl.cert_time_to_seconds(cert['notAfter'])
2378                self.assertLess(before, after)
2379                s.close()
2380
2381        @unittest.skipUnless(have_verify_flags(),
2382                            "verify_flags need OpenSSL > 0.9.8")
2383        def test_crl_check(self):
2384            if support.verbose:
2385                sys.stdout.write("\n")
2386
2387            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2388            server_context.load_cert_chain(SIGNED_CERTFILE)
2389
2390            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2391            context.verify_mode = ssl.CERT_REQUIRED
2392            context.load_verify_locations(SIGNING_CA)
2393            tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
2394            self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf)
2395
2396            # VERIFY_DEFAULT should pass
2397            server = ThreadedEchoServer(context=server_context, chatty=True)
2398            with server:
2399                with context.wrap_socket(socket.socket()) as s:
2400                    s.connect((HOST, server.port))
2401                    cert = s.getpeercert()
2402                    self.assertTrue(cert, "Can't get peer certificate.")
2403
2404            # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
2405            context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
2406
2407            server = ThreadedEchoServer(context=server_context, chatty=True)
2408            with server:
2409                with context.wrap_socket(socket.socket()) as s:
2410                    with self.assertRaisesRegex(ssl.SSLError,
2411                                                "certificate verify failed"):
2412                        s.connect((HOST, server.port))
2413
2414            # now load a CRL file. The CRL file is signed by the CA.
2415            context.load_verify_locations(CRLFILE)
2416
2417            server = ThreadedEchoServer(context=server_context, chatty=True)
2418            with server:
2419                with context.wrap_socket(socket.socket()) as s:
2420                    s.connect((HOST, server.port))
2421                    cert = s.getpeercert()
2422                    self.assertTrue(cert, "Can't get peer certificate.")
2423
2424        def test_check_hostname(self):
2425            if support.verbose:
2426                sys.stdout.write("\n")
2427
2428            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2429            server_context.load_cert_chain(SIGNED_CERTFILE)
2430
2431            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2432            context.verify_mode = ssl.CERT_REQUIRED
2433            context.check_hostname = True
2434            context.load_verify_locations(SIGNING_CA)
2435
2436            # correct hostname should verify
2437            server = ThreadedEchoServer(context=server_context, chatty=True)
2438            with server:
2439                with context.wrap_socket(socket.socket(),
2440                                         server_hostname="localhost") as s:
2441                    s.connect((HOST, server.port))
2442                    cert = s.getpeercert()
2443                    self.assertTrue(cert, "Can't get peer certificate.")
2444
2445            # incorrect hostname should raise an exception
2446            server = ThreadedEchoServer(context=server_context, chatty=True)
2447            with server:
2448                with context.wrap_socket(socket.socket(),
2449                                         server_hostname="invalid") as s:
2450                    with self.assertRaisesRegex(ssl.CertificateError,
2451                                                "hostname 'invalid' doesn't match 'localhost'"):
2452                        s.connect((HOST, server.port))
2453
2454            # missing server_hostname arg should cause an exception, too
2455            server = ThreadedEchoServer(context=server_context, chatty=True)
2456            with server:
2457                with socket.socket() as s:
2458                    with self.assertRaisesRegex(ValueError,
2459                                                "check_hostname requires server_hostname"):
2460                        context.wrap_socket(s)
2461
2462        def test_wrong_cert(self):
2463            """Connecting when the server rejects the client's certificate
2464
2465            Launch a server with CERT_REQUIRED, and check that trying to
2466            connect to it with a wrong client certificate fails.
2467            """
2468            certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
2469                                       "wrongcert.pem")
2470            server = ThreadedEchoServer(CERTFILE,
2471                                        certreqs=ssl.CERT_REQUIRED,
2472                                        cacerts=CERTFILE, chatty=False,
2473                                        connectionchatty=False)
2474            with server, \
2475                    socket.socket() as sock, \
2476                    test_wrap_socket(sock,
2477                                        certfile=certfile,
2478                                        ssl_version=ssl.PROTOCOL_TLSv1) as s:
2479                try:
2480                    # Expect either an SSL error about the server rejecting
2481                    # the connection, or a low-level connection reset (which
2482                    # sometimes happens on Windows)
2483                    s.connect((HOST, server.port))
2484                except ssl.SSLError as e:
2485                    if support.verbose:
2486                        sys.stdout.write("\nSSLError is %r\n" % e)
2487                except OSError as e:
2488                    if e.errno != errno.ECONNRESET:
2489                        raise
2490                    if support.verbose:
2491                        sys.stdout.write("\nsocket.error is %r\n" % e)
2492                else:
2493                    self.fail("Use of invalid cert should have failed!")
2494
2495        def test_rude_shutdown(self):
2496            """A brutal shutdown of an SSL server should raise an OSError
2497            in the client when attempting handshake.
2498            """
2499            listener_ready = threading.Event()
2500            listener_gone = threading.Event()
2501
2502            s = socket.socket()
2503            port = support.bind_port(s, HOST)
2504
2505            # `listener` runs in a thread.  It sits in an accept() until
2506            # the main thread connects.  Then it rudely closes the socket,
2507            # and sets Event `listener_gone` to let the main thread know
2508            # the socket is gone.
2509            def listener():
2510                s.listen()
2511                listener_ready.set()
2512                newsock, addr = s.accept()
2513                newsock.close()
2514                s.close()
2515                listener_gone.set()
2516
2517            def connector():
2518                listener_ready.wait()
2519                with socket.socket() as c:
2520                    c.connect((HOST, port))
2521                    listener_gone.wait()
2522                    try:
2523                        ssl_sock = test_wrap_socket(c)
2524                    except OSError:
2525                        pass
2526                    else:
2527                        self.fail('connecting to closed SSL socket should have failed')
2528
2529            t = threading.Thread(target=listener)
2530            t.start()
2531            try:
2532                connector()
2533            finally:
2534                t.join()
2535
2536        @skip_if_broken_ubuntu_ssl
2537        @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'),
2538                             "OpenSSL is compiled without SSLv2 support")
2539        def test_protocol_sslv2(self):
2540            """Connecting to an SSLv2 server with various client options"""
2541            if support.verbose:
2542                sys.stdout.write("\n")
2543            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
2544            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
2545            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
2546            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False)
2547            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2548                try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
2549            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
2550            # SSLv23 client with specific SSL options
2551            if no_sslv2_implies_sslv3_hello():
2552                # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
2553                try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
2554                                   client_options=ssl.OP_NO_SSLv2)
2555            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
2556                               client_options=ssl.OP_NO_SSLv3)
2557            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
2558                               client_options=ssl.OP_NO_TLSv1)
2559
2560        @skip_if_broken_ubuntu_ssl
2561        def test_protocol_sslv23(self):
2562            """Connecting to an SSLv23 server with various client options"""
2563            if support.verbose:
2564                sys.stdout.write("\n")
2565            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2566                try:
2567                    try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
2568                except OSError as x:
2569                    # this fails on some older versions of OpenSSL (0.9.7l, for instance)
2570                    if support.verbose:
2571                        sys.stdout.write(
2572                            " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
2573                            % str(x))
2574            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2575                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False)
2576            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
2577            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')
2578
2579            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2580                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
2581            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
2582            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
2583
2584            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2585                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
2586            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
2587            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
2588
2589            # Server with specific SSL options
2590            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2591                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False,
2592                               server_options=ssl.OP_NO_SSLv3)
2593            # Will choose TLSv1
2594            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True,
2595                               server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
2596            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False,
2597                               server_options=ssl.OP_NO_TLSv1)
2598
2599
2600        @skip_if_broken_ubuntu_ssl
2601        @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'),
2602                             "OpenSSL is compiled without SSLv3 support")
2603        def test_protocol_sslv3(self):
2604            """Connecting to an SSLv3 server with various client options"""
2605            if support.verbose:
2606                sys.stdout.write("\n")
2607            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
2608            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
2609            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
2610            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2611                try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
2612            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False,
2613                               client_options=ssl.OP_NO_SSLv3)
2614            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
2615            if no_sslv2_implies_sslv3_hello():
2616                # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
2617                try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23,
2618                                   False, client_options=ssl.OP_NO_SSLv2)
2619
2620        @skip_if_broken_ubuntu_ssl
2621        def test_protocol_tlsv1(self):
2622            """Connecting to a TLSv1 server with various client options"""
2623            if support.verbose:
2624                sys.stdout.write("\n")
2625            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
2626            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
2627            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
2628            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2629                try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
2630            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2631                try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
2632            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False,
2633                               client_options=ssl.OP_NO_TLSv1)
2634
2635        @skip_if_broken_ubuntu_ssl
2636        @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
2637                             "TLS version 1.1 not supported.")
2638        def test_protocol_tlsv1_1(self):
2639            """Connecting to a TLSv1.1 server with various client options.
2640               Testing against older TLS versions."""
2641            if support.verbose:
2642                sys.stdout.write("\n")
2643            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
2644            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2645                try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
2646            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2647                try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
2648            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
2649                               client_options=ssl.OP_NO_TLSv1_1)
2650
2651            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
2652            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
2653            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
2654
2655
2656        @skip_if_broken_ubuntu_ssl
2657        @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"),
2658                             "TLS version 1.2 not supported.")
2659        def test_protocol_tlsv1_2(self):
2660            """Connecting to a TLSv1.2 server with various client options.
2661               Testing against older TLS versions."""
2662            if support.verbose:
2663                sys.stdout.write("\n")
2664            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
2665                               server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
2666                               client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
2667            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2668                try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
2669            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2670                try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
2671            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
2672                               client_options=ssl.OP_NO_TLSv1_2)
2673
2674            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
2675            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
2676            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
2677            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
2678            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
2679
2680        def test_starttls(self):
2681            """Switching from clear text to encrypted and back again."""
2682            msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
2683
2684            server = ThreadedEchoServer(CERTFILE,
2685                                        ssl_version=ssl.PROTOCOL_TLSv1,
2686                                        starttls_server=True,
2687                                        chatty=True,
2688                                        connectionchatty=True)
2689            wrapped = False
2690            with server:
2691                s = socket.socket()
2692                s.setblocking(1)
2693                s.connect((HOST, server.port))
2694                if support.verbose:
2695                    sys.stdout.write("\n")
2696                for indata in msgs:
2697                    if support.verbose:
2698                        sys.stdout.write(
2699                            " client:  sending %r...\n" % indata)
2700                    if wrapped:
2701                        conn.write(indata)
2702                        outdata = conn.read()
2703                    else:
2704                        s.send(indata)
2705                        outdata = s.recv(1024)
2706                    msg = outdata.strip().lower()
2707                    if indata == b"STARTTLS" and msg.startswith(b"ok"):
2708                        # STARTTLS ok, switch to secure mode
2709                        if support.verbose:
2710                            sys.stdout.write(
2711                                " client:  read %r from server, starting TLS...\n"
2712                                % msg)
2713                        conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
2714                        wrapped = True
2715                    elif indata == b"ENDTLS" and msg.startswith(b"ok"):
2716                        # ENDTLS ok, switch back to clear text
2717                        if support.verbose:
2718                            sys.stdout.write(
2719                                " client:  read %r from server, ending TLS...\n"
2720                                % msg)
2721                        s = conn.unwrap()
2722                        wrapped = False
2723                    else:
2724                        if support.verbose:
2725                            sys.stdout.write(
2726                                " client:  read %r from server\n" % msg)
2727                if support.verbose:
2728                    sys.stdout.write(" client:  closing connection.\n")
2729                if wrapped:
2730                    conn.write(b"over\n")
2731                else:
2732                    s.send(b"over\n")
2733                if wrapped:
2734                    conn.close()
2735                else:
2736                    s.close()
2737
2738        def test_socketserver(self):
2739            """Using socketserver to create and manage SSL connections."""
2740            server = make_https_server(self, certfile=CERTFILE)
2741            # try to connect
2742            if support.verbose:
2743                sys.stdout.write('\n')
2744            with open(CERTFILE, 'rb') as f:
2745                d1 = f.read()
2746            d2 = ''
2747            # now fetch the same data from the HTTPS server
2748            url = 'https://localhost:%d/%s' % (
2749                server.port, os.path.split(CERTFILE)[1])
2750            context = ssl.create_default_context(cafile=CERTFILE)
2751            f = urllib.request.urlopen(url, context=context)
2752            try:
2753                dlen = f.info().get("content-length")
2754                if dlen and (int(dlen) > 0):
2755                    d2 = f.read(int(dlen))
2756                    if support.verbose:
2757                        sys.stdout.write(
2758                            " client: read %d bytes from remote server '%s'\n"
2759                            % (len(d2), server))
2760            finally:
2761                f.close()
2762            self.assertEqual(d1, d2)
2763
2764        def test_asyncore_server(self):
2765            """Check the example asyncore integration."""
2766            if support.verbose:
2767                sys.stdout.write("\n")
2768
2769            indata = b"FOO\n"
2770            server = AsyncoreEchoServer(CERTFILE)
2771            with server:
2772                s = test_wrap_socket(socket.socket())
2773                s.connect(('127.0.0.1', server.port))
2774                if support.verbose:
2775                    sys.stdout.write(
2776                        " client:  sending %r...\n" % indata)
2777                s.write(indata)
2778                outdata = s.read()
2779                if support.verbose:
2780                    sys.stdout.write(" client:  read %r\n" % outdata)
2781                if outdata != indata.lower():
2782                    self.fail(
2783                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
2784                        % (outdata[:20], len(outdata),
2785                           indata[:20].lower(), len(indata)))
2786                s.write(b"over\n")
2787                if support.verbose:
2788                    sys.stdout.write(" client:  closing connection.\n")
2789                s.close()
2790                if support.verbose:
2791                    sys.stdout.write(" client:  connection closed.\n")
2792
2793        def test_recv_send(self):
2794            """Test recv(), send() and friends."""
2795            if support.verbose:
2796                sys.stdout.write("\n")
2797
2798            server = ThreadedEchoServer(CERTFILE,
2799                                        certreqs=ssl.CERT_NONE,
2800                                        ssl_version=ssl.PROTOCOL_TLSv1,
2801                                        cacerts=CERTFILE,
2802                                        chatty=True,
2803                                        connectionchatty=False)
2804            with server:
2805                s = test_wrap_socket(socket.socket(),
2806                                    server_side=False,
2807                                    certfile=CERTFILE,
2808                                    ca_certs=CERTFILE,
2809                                    cert_reqs=ssl.CERT_NONE,
2810                                    ssl_version=ssl.PROTOCOL_TLSv1)
2811                s.connect((HOST, server.port))
2812                # helper methods for standardising recv* method signatures
2813                def _recv_into():
2814                    b = bytearray(b"\0"*100)
2815                    count = s.recv_into(b)
2816                    return b[:count]
2817
2818                def _recvfrom_into():
2819                    b = bytearray(b"\0"*100)
2820                    count, addr = s.recvfrom_into(b)
2821                    return b[:count]
2822
2823                # (name, method, expect success?, *args, return value func)
2824                send_methods = [
2825                    ('send', s.send, True, [], len),
2826                    ('sendto', s.sendto, False, ["some.address"], len),
2827                    ('sendall', s.sendall, True, [], lambda x: None),
2828                ]
2829                # (name, method, whether to expect success, *args)
2830                recv_methods = [
2831                    ('recv', s.recv, True, []),
2832                    ('recvfrom', s.recvfrom, False, ["some.address"]),
2833                    ('recv_into', _recv_into, True, []),
2834                    ('recvfrom_into', _recvfrom_into, False, []),
2835                ]
2836                data_prefix = "PREFIX_"
2837
2838                for (meth_name, send_meth, expect_success, args,
2839                        ret_val_meth) in send_methods:
2840                    indata = (data_prefix + meth_name).encode('ascii')
2841                    try:
2842                        ret = send_meth(indata, *args)
2843                        msg = "sending with {}".format(meth_name)
2844                        self.assertEqual(ret, ret_val_meth(indata), msg=msg)
2845                        outdata = s.read()
2846                        if outdata != indata.lower():
2847                            self.fail(
2848                                "While sending with <<{name:s}>> bad data "
2849                                "<<{outdata:r}>> ({nout:d}) received; "
2850                                "expected <<{indata:r}>> ({nin:d})\n".format(
2851                                    name=meth_name, outdata=outdata[:20],
2852                                    nout=len(outdata),
2853                                    indata=indata[:20], nin=len(indata)
2854                                )
2855                            )
2856                    except ValueError as e:
2857                        if expect_success:
2858                            self.fail(
2859                                "Failed to send with method <<{name:s}>>; "
2860                                "expected to succeed.\n".format(name=meth_name)
2861                            )
2862                        if not str(e).startswith(meth_name):
2863                            self.fail(
2864                                "Method <<{name:s}>> failed with unexpected "
2865                                "exception message: {exp:s}\n".format(
2866                                    name=meth_name, exp=e
2867                                )
2868                            )
2869
2870                for meth_name, recv_meth, expect_success, args in recv_methods:
2871                    indata = (data_prefix + meth_name).encode('ascii')
2872                    try:
2873                        s.send(indata)
2874                        outdata = recv_meth(*args)
2875                        if outdata != indata.lower():
2876                            self.fail(
2877                                "While receiving with <<{name:s}>> bad data "
2878                                "<<{outdata:r}>> ({nout:d}) received; "
2879                                "expected <<{indata:r}>> ({nin:d})\n".format(
2880                                    name=meth_name, outdata=outdata[:20],
2881                                    nout=len(outdata),
2882                                    indata=indata[:20], nin=len(indata)
2883                                )
2884                            )
2885                    except ValueError as e:
2886                        if expect_success:
2887                            self.fail(
2888                                "Failed to receive with method <<{name:s}>>; "
2889                                "expected to succeed.\n".format(name=meth_name)
2890                            )
2891                        if not str(e).startswith(meth_name):
2892                            self.fail(
2893                                "Method <<{name:s}>> failed with unexpected "
2894                                "exception message: {exp:s}\n".format(
2895                                    name=meth_name, exp=e
2896                                )
2897                            )
2898                        # consume data
2899                        s.read()
2900
2901                # read(-1, buffer) is supported, even though read(-1) is not
2902                data = b"data"
2903                s.send(data)
2904                buffer = bytearray(len(data))
2905                self.assertEqual(s.read(-1, buffer), len(data))
2906                self.assertEqual(buffer, data)
2907
2908                # Make sure sendmsg et al are disallowed to avoid
2909                # inadvertent disclosure of data and/or corruption
2910                # of the encrypted data stream
2911                self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
2912                self.assertRaises(NotImplementedError, s.recvmsg, 100)
2913                self.assertRaises(NotImplementedError,
2914                                  s.recvmsg_into, bytearray(100))
2915
2916                s.write(b"over\n")
2917
2918                self.assertRaises(ValueError, s.recv, -1)
2919                self.assertRaises(ValueError, s.read, -1)
2920
2921                s.close()
2922
2923        def test_recv_zero(self):
2924            server = ThreadedEchoServer(CERTFILE)
2925            server.__enter__()
2926            self.addCleanup(server.__exit__, None, None)
2927            s = socket.create_connection((HOST, server.port))
2928            self.addCleanup(s.close)
2929            s = test_wrap_socket(s, suppress_ragged_eofs=False)
2930            self.addCleanup(s.close)
2931
2932            # recv/read(0) should return no data
2933            s.send(b"data")
2934            self.assertEqual(s.recv(0), b"")
2935            self.assertEqual(s.read(0), b"")
2936            self.assertEqual(s.read(), b"data")
2937
2938            # Should not block if the other end sends no data
2939            s.setblocking(False)
2940            self.assertEqual(s.recv(0), b"")
2941            self.assertEqual(s.recv_into(bytearray()), 0)
2942
2943        def test_nonblocking_send(self):
2944            server = ThreadedEchoServer(CERTFILE,
2945                                        certreqs=ssl.CERT_NONE,
2946                                        ssl_version=ssl.PROTOCOL_TLSv1,
2947                                        cacerts=CERTFILE,
2948                                        chatty=True,
2949                                        connectionchatty=False)
2950            with server:
2951                s = test_wrap_socket(socket.socket(),
2952                                    server_side=False,
2953                                    certfile=CERTFILE,
2954                                    ca_certs=CERTFILE,
2955                                    cert_reqs=ssl.CERT_NONE,
2956                                    ssl_version=ssl.PROTOCOL_TLSv1)
2957                s.connect((HOST, server.port))
2958                s.setblocking(False)
2959
2960                # If we keep sending data, at some point the buffers
2961                # will be full and the call will block
2962                buf = bytearray(8192)
2963                def fill_buffer():
2964                    while True:
2965                        s.send(buf)
2966                self.assertRaises((ssl.SSLWantWriteError,
2967                                   ssl.SSLWantReadError), fill_buffer)
2968
2969                # Now read all the output and discard it
2970                s.setblocking(True)
2971                s.close()
2972
2973        def test_handshake_timeout(self):
2974            # Issue #5103: SSL handshake must respect the socket timeout
2975            server = socket.socket(socket.AF_INET)
2976            host = "127.0.0.1"
2977            port = support.bind_port(server)
2978            started = threading.Event()
2979            finish = False
2980
2981            def serve():
2982                server.listen()
2983                started.set()
2984                conns = []
2985                while not finish:
2986                    r, w, e = select.select([server], [], [], 0.1)
2987                    if server in r:
2988                        # Let the socket hang around rather than having
2989                        # it closed by garbage collection.
2990                        conns.append(server.accept()[0])
2991                for sock in conns:
2992                    sock.close()
2993
2994            t = threading.Thread(target=serve)
2995            t.start()
2996            started.wait()
2997
2998            try:
2999                try:
3000                    c = socket.socket(socket.AF_INET)
3001                    c.settimeout(0.2)
3002                    c.connect((host, port))
3003                    # Will attempt handshake and time out
3004                    self.assertRaisesRegex(socket.timeout, "timed out",
3005                                           test_wrap_socket, c)
3006                finally:
3007                    c.close()
3008                try:
3009                    c = socket.socket(socket.AF_INET)
3010                    c = test_wrap_socket(c)
3011                    c.settimeout(0.2)
3012                    # Will attempt handshake and time out
3013                    self.assertRaisesRegex(socket.timeout, "timed out",
3014                                           c.connect, (host, port))
3015                finally:
3016                    c.close()
3017            finally:
3018                finish = True
3019                t.join()
3020                server.close()
3021
3022        def test_server_accept(self):
3023            # Issue #16357: accept() on a SSLSocket created through
3024            # SSLContext.wrap_socket().
3025            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3026            context.verify_mode = ssl.CERT_REQUIRED
3027            context.load_verify_locations(CERTFILE)
3028            context.load_cert_chain(CERTFILE)
3029            server = socket.socket(socket.AF_INET)
3030            host = "127.0.0.1"
3031            port = support.bind_port(server)
3032            server = context.wrap_socket(server, server_side=True)
3033            self.assertTrue(server.server_side)
3034
3035            evt = threading.Event()
3036            remote = None
3037            peer = None
3038            def serve():
3039                nonlocal remote, peer
3040                server.listen()
3041                # Block on the accept and wait on the connection to close.
3042                evt.set()
3043                remote, peer = server.accept()
3044                remote.recv(1)
3045
3046            t = threading.Thread(target=serve)
3047            t.start()
3048            # Client wait until server setup and perform a connect.
3049            evt.wait()
3050            client = context.wrap_socket(socket.socket())
3051            client.connect((host, port))
3052            client_addr = client.getsockname()
3053            client.close()
3054            t.join()
3055            remote.close()
3056            server.close()
3057            # Sanity checks.
3058            self.assertIsInstance(remote, ssl.SSLSocket)
3059            self.assertEqual(peer, client_addr)
3060
3061        def test_getpeercert_enotconn(self):
3062            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3063            with context.wrap_socket(socket.socket()) as sock:
3064                with self.assertRaises(OSError) as cm:
3065                    sock.getpeercert()
3066                self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3067
3068        def test_do_handshake_enotconn(self):
3069            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3070            with context.wrap_socket(socket.socket()) as sock:
3071                with self.assertRaises(OSError) as cm:
3072                    sock.do_handshake()
3073                self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3074
3075        def test_default_ciphers(self):
3076            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3077            try:
3078                # Force a set of weak ciphers on our client context
3079                context.set_ciphers("DES")
3080            except ssl.SSLError:
3081                self.skipTest("no DES cipher available")
3082            with ThreadedEchoServer(CERTFILE,
3083                                    ssl_version=ssl.PROTOCOL_SSLv23,
3084                                    chatty=False) as server:
3085                with context.wrap_socket(socket.socket()) as s:
3086                    with self.assertRaises(OSError):
3087                        s.connect((HOST, server.port))
3088            self.assertIn("no shared cipher", str(server.conn_errors[0]))
3089
3090        def test_version_basic(self):
3091            """
3092            Basic tests for SSLSocket.version().
3093            More tests are done in the test_protocol_*() methods.
3094            """
3095            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3096            with ThreadedEchoServer(CERTFILE,
3097                                    ssl_version=ssl.PROTOCOL_TLSv1,
3098                                    chatty=False) as server:
3099                with context.wrap_socket(socket.socket()) as s:
3100                    self.assertIs(s.version(), None)
3101                    s.connect((HOST, server.port))
3102                    self.assertEqual(s.version(), 'TLSv1')
3103                self.assertIs(s.version(), None)
3104
3105        @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
3106        def test_default_ecdh_curve(self):
3107            # Issue #21015: elliptic curve-based Diffie Hellman key exchange
3108            # should be enabled by default on SSL contexts.
3109            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3110            context.load_cert_chain(CERTFILE)
3111            # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
3112            # explicitly using the 'ECCdraft' cipher alias.  Otherwise,
3113            # our default cipher list should prefer ECDH-based ciphers
3114            # automatically.
3115            if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
3116                context.set_ciphers("ECCdraft:ECDH")
3117            with ThreadedEchoServer(context=context) as server:
3118                with context.wrap_socket(socket.socket()) as s:
3119                    s.connect((HOST, server.port))
3120                    self.assertIn("ECDH", s.cipher()[0])
3121
3122        @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
3123                             "'tls-unique' channel binding not available")
3124        def test_tls_unique_channel_binding(self):
3125            """Test tls-unique channel binding."""
3126            if support.verbose:
3127                sys.stdout.write("\n")
3128
3129            server = ThreadedEchoServer(CERTFILE,
3130                                        certreqs=ssl.CERT_NONE,
3131                                        ssl_version=ssl.PROTOCOL_TLSv1,
3132                                        cacerts=CERTFILE,
3133                                        chatty=True,
3134                                        connectionchatty=False)
3135            with server:
3136                s = test_wrap_socket(socket.socket(),
3137                                    server_side=False,
3138                                    certfile=CERTFILE,
3139                                    ca_certs=CERTFILE,
3140                                    cert_reqs=ssl.CERT_NONE,
3141                                    ssl_version=ssl.PROTOCOL_TLSv1)
3142                s.connect((HOST, server.port))
3143                # get the data
3144                cb_data = s.get_channel_binding("tls-unique")
3145                if support.verbose:
3146                    sys.stdout.write(" got channel binding data: {0!r}\n"
3147                                     .format(cb_data))
3148
3149                # check if it is sane
3150                self.assertIsNotNone(cb_data)
3151                self.assertEqual(len(cb_data), 12) # True for TLSv1
3152
3153                # and compare with the peers version
3154                s.write(b"CB tls-unique\n")
3155                peer_data_repr = s.read().strip()
3156                self.assertEqual(peer_data_repr,
3157                                 repr(cb_data).encode("us-ascii"))
3158                s.close()
3159
3160                # now, again
3161                s = test_wrap_socket(socket.socket(),
3162                                    server_side=False,
3163                                    certfile=CERTFILE,
3164                                    ca_certs=CERTFILE,
3165                                    cert_reqs=ssl.CERT_NONE,
3166                                    ssl_version=ssl.PROTOCOL_TLSv1)
3167                s.connect((HOST, server.port))
3168                new_cb_data = s.get_channel_binding("tls-unique")
3169                if support.verbose:
3170                    sys.stdout.write(" got another channel binding data: {0!r}\n"
3171                                     .format(new_cb_data))
3172                # is it really unique
3173                self.assertNotEqual(cb_data, new_cb_data)
3174                self.assertIsNotNone(cb_data)
3175                self.assertEqual(len(cb_data), 12) # True for TLSv1
3176                s.write(b"CB tls-unique\n")
3177                peer_data_repr = s.read().strip()
3178                self.assertEqual(peer_data_repr,
3179                                 repr(new_cb_data).encode("us-ascii"))
3180                s.close()
3181
3182        def test_compression(self):
3183            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3184            context.load_cert_chain(CERTFILE)
3185            stats = server_params_test(context, context,
3186                                       chatty=True, connectionchatty=True)
3187            if support.verbose:
3188                sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
3189            self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
3190
3191        @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
3192                             "ssl.OP_NO_COMPRESSION needed for this test")
3193        def test_compression_disabled(self):
3194            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3195            context.load_cert_chain(CERTFILE)
3196            context.options |= ssl.OP_NO_COMPRESSION
3197            stats = server_params_test(context, context,
3198                                       chatty=True, connectionchatty=True)
3199            self.assertIs(stats['compression'], None)
3200
3201        def test_dh_params(self):
3202            # Check we can get a connection with ephemeral Diffie-Hellman
3203            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3204            context.load_cert_chain(CERTFILE)
3205            context.load_dh_params(DHFILE)
3206            context.set_ciphers("kEDH")
3207            stats = server_params_test(context, context,
3208                                       chatty=True, connectionchatty=True)
3209            cipher = stats["cipher"][0]
3210            parts = cipher.split("-")
3211            if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
3212                self.fail("Non-DH cipher: " + cipher[0])
3213
3214        def test_selected_alpn_protocol(self):
3215            # selected_alpn_protocol() is None unless ALPN is used.
3216            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3217            context.load_cert_chain(CERTFILE)
3218            stats = server_params_test(context, context,
3219                                       chatty=True, connectionchatty=True)
3220            self.assertIs(stats['client_alpn_protocol'], None)
3221
3222        @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
3223        def test_selected_alpn_protocol_if_server_uses_alpn(self):
3224            # selected_alpn_protocol() is None unless ALPN is used by the client.
3225            client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3226            client_context.load_verify_locations(CERTFILE)
3227            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3228            server_context.load_cert_chain(CERTFILE)
3229            server_context.set_alpn_protocols(['foo', 'bar'])
3230            stats = server_params_test(client_context, server_context,
3231                                       chatty=True, connectionchatty=True)
3232            self.assertIs(stats['client_alpn_protocol'], None)
3233
3234        @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
3235        def test_alpn_protocols(self):
3236            server_protocols = ['foo', 'bar', 'milkshake']
3237            protocol_tests = [
3238                (['foo', 'bar'], 'foo'),
3239                (['bar', 'foo'], 'foo'),
3240                (['milkshake'], 'milkshake'),
3241                (['http/3.0', 'http/4.0'], None)
3242            ]
3243            for client_protocols, expected in protocol_tests:
3244                server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
3245                server_context.load_cert_chain(CERTFILE)
3246                server_context.set_alpn_protocols(server_protocols)
3247                client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
3248                client_context.load_cert_chain(CERTFILE)
3249                client_context.set_alpn_protocols(client_protocols)
3250
3251                try:
3252                    stats = server_params_test(client_context,
3253                                               server_context,
3254                                               chatty=True,
3255                                               connectionchatty=True)
3256                except ssl.SSLError as e:
3257                    stats = e
3258
3259                if expected is None and IS_OPENSSL_1_1:
3260                    # OpenSSL 1.1.0 raises handshake error
3261                    self.assertIsInstance(stats, ssl.SSLError)
3262                else:
3263                    msg = "failed trying %s (s) and %s (c).\n" \
3264                        "was expecting %s, but got %%s from the %%s" \
3265                            % (str(server_protocols), str(client_protocols),
3266                                str(expected))
3267                    client_result = stats['client_alpn_protocol']
3268                    self.assertEqual(client_result, expected,
3269                                     msg % (client_result, "client"))
3270                    server_result = stats['server_alpn_protocols'][-1] \
3271                        if len(stats['server_alpn_protocols']) else 'nothing'
3272                    self.assertEqual(server_result, expected,
3273                                     msg % (server_result, "server"))
3274
3275        def test_selected_npn_protocol(self):
3276            # selected_npn_protocol() is None unless NPN is used
3277            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3278            context.load_cert_chain(CERTFILE)
3279            stats = server_params_test(context, context,
3280                                       chatty=True, connectionchatty=True)
3281            self.assertIs(stats['client_npn_protocol'], None)
3282
3283        @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
3284        def test_npn_protocols(self):
3285            server_protocols = ['http/1.1', 'spdy/2']
3286            protocol_tests = [
3287                (['http/1.1', 'spdy/2'], 'http/1.1'),
3288                (['spdy/2', 'http/1.1'], 'http/1.1'),
3289                (['spdy/2', 'test'], 'spdy/2'),
3290                (['abc', 'def'], 'abc')
3291            ]
3292            for client_protocols, expected in protocol_tests:
3293                server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3294                server_context.load_cert_chain(CERTFILE)
3295                server_context.set_npn_protocols(server_protocols)
3296                client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3297                client_context.load_cert_chain(CERTFILE)
3298                client_context.set_npn_protocols(client_protocols)
3299                stats = server_params_test(client_context, server_context,
3300                                           chatty=True, connectionchatty=True)
3301
3302                msg = "failed trying %s (s) and %s (c).\n" \
3303                      "was expecting %s, but got %%s from the %%s" \
3304                          % (str(server_protocols), str(client_protocols),
3305                             str(expected))
3306                client_result = stats['client_npn_protocol']
3307                self.assertEqual(client_result, expected, msg % (client_result, "client"))
3308                server_result = stats['server_npn_protocols'][-1] \
3309                    if len(stats['server_npn_protocols']) else 'nothing'
3310                self.assertEqual(server_result, expected, msg % (server_result, "server"))
3311
3312        def sni_contexts(self):
3313            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3314            server_context.load_cert_chain(SIGNED_CERTFILE)
3315            other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3316            other_context.load_cert_chain(SIGNED_CERTFILE2)
3317            client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3318            client_context.verify_mode = ssl.CERT_REQUIRED
3319            client_context.load_verify_locations(SIGNING_CA)
3320            return server_context, other_context, client_context
3321
3322        def check_common_name(self, stats, name):
3323            cert = stats['peercert']
3324            self.assertIn((('commonName', name),), cert['subject'])
3325
3326        @needs_sni
3327        def test_sni_callback(self):
3328            calls = []
3329            server_context, other_context, client_context = self.sni_contexts()
3330
3331            def servername_cb(ssl_sock, server_name, initial_context):
3332                calls.append((server_name, initial_context))
3333                if server_name is not None:
3334                    ssl_sock.context = other_context
3335            server_context.set_servername_callback(servername_cb)
3336
3337            stats = server_params_test(client_context, server_context,
3338                                       chatty=True,
3339                                       sni_name='supermessage')
3340            # The hostname was fetched properly, and the certificate was
3341            # changed for the connection.
3342            self.assertEqual(calls, [("supermessage", server_context)])
3343            # CERTFILE4 was selected
3344            self.check_common_name(stats, 'fakehostname')
3345
3346            calls = []
3347            # The callback is called with server_name=None
3348            stats = server_params_test(client_context, server_context,
3349                                       chatty=True,
3350                                       sni_name=None)
3351            self.assertEqual(calls, [(None, server_context)])
3352            self.check_common_name(stats, 'localhost')
3353
3354            # Check disabling the callback
3355            calls = []
3356            server_context.set_servername_callback(None)
3357
3358            stats = server_params_test(client_context, server_context,
3359                                       chatty=True,
3360                                       sni_name='notfunny')
3361            # Certificate didn't change
3362            self.check_common_name(stats, 'localhost')
3363            self.assertEqual(calls, [])
3364
3365        @needs_sni
3366        def test_sni_callback_alert(self):
3367            # Returning a TLS alert is reflected to the connecting client
3368            server_context, other_context, client_context = self.sni_contexts()
3369
3370            def cb_returning_alert(ssl_sock, server_name, initial_context):
3371                return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
3372            server_context.set_servername_callback(cb_returning_alert)
3373
3374            with self.assertRaises(ssl.SSLError) as cm:
3375                stats = server_params_test(client_context, server_context,
3376                                           chatty=False,
3377                                           sni_name='supermessage')
3378            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
3379
3380        @needs_sni
3381        def test_sni_callback_raising(self):
3382            # Raising fails the connection with a TLS handshake failure alert.
3383            server_context, other_context, client_context = self.sni_contexts()
3384
3385            def cb_raising(ssl_sock, server_name, initial_context):
3386                1/0
3387            server_context.set_servername_callback(cb_raising)
3388
3389            with self.assertRaises(ssl.SSLError) as cm, \
3390                 support.captured_stderr() as stderr:
3391                stats = server_params_test(client_context, server_context,
3392                                           chatty=False,
3393                                           sni_name='supermessage')
3394            self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE')
3395            self.assertIn("ZeroDivisionError", stderr.getvalue())
3396
3397        @needs_sni
3398        def test_sni_callback_wrong_return_type(self):
3399            # Returning the wrong return type terminates the TLS connection
3400            # with an internal error alert.
3401            server_context, other_context, client_context = self.sni_contexts()
3402
3403            def cb_wrong_return_type(ssl_sock, server_name, initial_context):
3404                return "foo"
3405            server_context.set_servername_callback(cb_wrong_return_type)
3406
3407            with self.assertRaises(ssl.SSLError) as cm, \
3408                 support.captured_stderr() as stderr:
3409                stats = server_params_test(client_context, server_context,
3410                                           chatty=False,
3411                                           sni_name='supermessage')
3412            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
3413            self.assertIn("TypeError", stderr.getvalue())
3414
3415        def test_shared_ciphers(self):
3416            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3417            server_context.load_cert_chain(SIGNED_CERTFILE)
3418            client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3419            client_context.verify_mode = ssl.CERT_REQUIRED
3420            client_context.load_verify_locations(SIGNING_CA)
3421            if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
3422                client_context.set_ciphers("AES128:AES256")
3423                server_context.set_ciphers("AES256")
3424                alg1 = "AES256"
3425                alg2 = "AES-256"
3426            else:
3427                client_context.set_ciphers("AES:3DES")
3428                server_context.set_ciphers("3DES")
3429                alg1 = "3DES"
3430                alg2 = "DES-CBC3"
3431
3432            stats = server_params_test(client_context, server_context)
3433            ciphers = stats['server_shared_ciphers'][0]
3434            self.assertGreater(len(ciphers), 0)
3435            for name, tls_version, bits in ciphers:
3436                if not alg1 in name.split("-") and alg2 not in name:
3437                    self.fail(name)
3438
3439        def test_read_write_after_close_raises_valuerror(self):
3440            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3441            context.verify_mode = ssl.CERT_REQUIRED
3442            context.load_verify_locations(CERTFILE)
3443            context.load_cert_chain(CERTFILE)
3444            server = ThreadedEchoServer(context=context, chatty=False)
3445
3446            with server:
3447                s = context.wrap_socket(socket.socket())
3448                s.connect((HOST, server.port))
3449                s.close()
3450
3451                self.assertRaises(ValueError, s.read, 1024)
3452                self.assertRaises(ValueError, s.write, b'hello')
3453
3454        def test_sendfile(self):
3455            TEST_DATA = b"x" * 512
3456            with open(support.TESTFN, 'wb') as f:
3457                f.write(TEST_DATA)
3458            self.addCleanup(support.unlink, support.TESTFN)
3459            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3460            context.verify_mode = ssl.CERT_REQUIRED
3461            context.load_verify_locations(CERTFILE)
3462            context.load_cert_chain(CERTFILE)
3463            server = ThreadedEchoServer(context=context, chatty=False)
3464            with server:
3465                with context.wrap_socket(socket.socket()) as s:
3466                    s.connect((HOST, server.port))
3467                    with open(support.TESTFN, 'rb') as file:
3468                        s.sendfile(file)
3469                        self.assertEqual(s.recv(1024), TEST_DATA)
3470
3471        def test_session(self):
3472            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3473            server_context.load_cert_chain(SIGNED_CERTFILE)
3474            client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3475            client_context.verify_mode = ssl.CERT_REQUIRED
3476            client_context.load_verify_locations(SIGNING_CA)
3477
3478            # first connection without session
3479            stats = server_params_test(client_context, server_context)
3480            session = stats['session']
3481            self.assertTrue(session.id)
3482            self.assertGreater(session.time, 0)
3483            self.assertGreater(session.timeout, 0)
3484            self.assertTrue(session.has_ticket)
3485            if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
3486                self.assertGreater(session.ticket_lifetime_hint, 0)
3487            self.assertFalse(stats['session_reused'])
3488            sess_stat = server_context.session_stats()
3489            self.assertEqual(sess_stat['accept'], 1)
3490            self.assertEqual(sess_stat['hits'], 0)
3491
3492            # reuse session
3493            stats = server_params_test(client_context, server_context, session=session)
3494            sess_stat = server_context.session_stats()
3495            self.assertEqual(sess_stat['accept'], 2)
3496            self.assertEqual(sess_stat['hits'], 1)
3497            self.assertTrue(stats['session_reused'])
3498            session2 = stats['session']
3499            self.assertEqual(session2.id, session.id)
3500            self.assertEqual(session2, session)
3501            self.assertIsNot(session2, session)
3502            self.assertGreaterEqual(session2.time, session.time)
3503            self.assertGreaterEqual(session2.timeout, session.timeout)
3504
3505            # another one without session
3506            stats = server_params_test(client_context, server_context)
3507            self.assertFalse(stats['session_reused'])
3508            session3 = stats['session']
3509            self.assertNotEqual(session3.id, session.id)
3510            self.assertNotEqual(session3, session)
3511            sess_stat = server_context.session_stats()
3512            self.assertEqual(sess_stat['accept'], 3)
3513            self.assertEqual(sess_stat['hits'], 1)
3514
3515            # reuse session again
3516            stats = server_params_test(client_context, server_context, session=session)
3517            self.assertTrue(stats['session_reused'])
3518            session4 = stats['session']
3519            self.assertEqual(session4.id, session.id)
3520            self.assertEqual(session4, session)
3521            self.assertGreaterEqual(session4.time, session.time)
3522            self.assertGreaterEqual(session4.timeout, session.timeout)
3523            sess_stat = server_context.session_stats()
3524            self.assertEqual(sess_stat['accept'], 4)
3525            self.assertEqual(sess_stat['hits'], 2)
3526
3527        def test_session_handling(self):
3528            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3529            context.verify_mode = ssl.CERT_REQUIRED
3530            context.load_verify_locations(CERTFILE)
3531            context.load_cert_chain(CERTFILE)
3532
3533            context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3534            context2.verify_mode = ssl.CERT_REQUIRED
3535            context2.load_verify_locations(CERTFILE)
3536            context2.load_cert_chain(CERTFILE)
3537
3538            server = ThreadedEchoServer(context=context, chatty=False)
3539            with server:
3540                with context.wrap_socket(socket.socket()) as s:
3541                    # session is None before handshake
3542                    self.assertEqual(s.session, None)
3543                    self.assertEqual(s.session_reused, None)
3544                    s.connect((HOST, server.port))
3545                    session = s.session
3546                    self.assertTrue(session)
3547                    with self.assertRaises(TypeError) as e:
3548                        s.session = object
3549                    self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
3550
3551                with context.wrap_socket(socket.socket()) as s:
3552                    s.connect((HOST, server.port))
3553                    # cannot set session after handshake
3554                    with self.assertRaises(ValueError) as e:
3555                        s.session = session
3556                    self.assertEqual(str(e.exception),
3557                                     'Cannot set session after handshake.')
3558
3559                with context.wrap_socket(socket.socket()) as s:
3560                    # can set session before handshake and before the
3561                    # connection was established
3562                    s.session = session
3563                    s.connect((HOST, server.port))
3564                    self.assertEqual(s.session.id, session.id)
3565                    self.assertEqual(s.session, session)
3566                    self.assertEqual(s.session_reused, True)
3567
3568                with context2.wrap_socket(socket.socket()) as s:
3569                    # cannot re-use session with a different SSLContext
3570                    with self.assertRaises(ValueError) as e:
3571                        s.session = session
3572                        s.connect((HOST, server.port))
3573                    self.assertEqual(str(e.exception),
3574                                     'Session refers to a different SSLContext.')
3575
3576
def test_main(verbose=False):
    """Run the SSL test classes, printing a configuration banner if verbose.

    The *verbose* parameter is not consulted; the banner is controlled by
    ``support.verbose``.  Raises ``support.TestFailed`` when one of the
    required certificate files does not exist.
    """
    if support.verbose:
        import warnings
        plats = {
            'Linux': platform.linux_distribution,
            'Mac': platform.mac_ver,
            'Windows': platform.win32_ver,
        }
        with warnings.catch_warnings():
            # Silence the deprecation notice for linux_distribution().
            warnings.filterwarnings(
                'ignore',
                r'dist\(\) and linux_distribution\(\) '
                'functions are deprecated .*',
                PendingDeprecationWarning,
            )
            plat = None
            for label, probe in plats.items():
                details = probe()
                # Take the first probe whose leading field is non-empty.
                if details and details[0]:
                    plat = '%s %r' % (label, details)
                    break
            if plat is None:
                plat = repr(platform.platform())
        print("test_ssl: testing with %r %r" %
              (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
        print("          under %s" % plat)
        print("          HAS_SNI = %r" % ssl.HAS_SNI)
        print("          OP_ALL = 0x%8x" % ssl.OP_ALL)
        try:
            print("          OP_NO_TLSv1_1 = 0x%8x" % ssl.OP_NO_TLSv1_1)
        except AttributeError:
            # The constant is not always present in the ssl module.
            pass

    required_files = (
        CERTFILE, BYTES_CERTFILE,
        ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
        SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
        BADCERT, BADKEY, EMPTYCERT,
    )
    for path in required_files:
        if os.path.exists(path):
            continue
        raise support.TestFailed("Can't read certificate file %r" % path)

    suites = [
        ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
        SimpleBackgroundTests,
    ]

    if support.is_resource_enabled('network'):
        suites.append(NetworkedTests)

    if _have_threads:
        thread_info = support.threading_setup()
        if thread_info:
            suites.append(ThreadedTests)

    try:
        support.run_unittest(*suites)
    finally:
        # thread_info is only bound when threading is available.
        if _have_threads:
            support.threading_cleanup(*thread_info)
3635
# Run the test suite via test_main() when this file is executed as a script.
3636if __name__ == "__main__":
3637    test_main()
3638