1# Wrapper module for _ssl, providing some additional facilities
2# implemented in Python.  Written by Bill Janssen.
3
4"""This module provides some more Pythonic support for SSL.
5
6Object types:
7
8  SSLSocket -- subtype of socket.socket which does SSL over the socket
9
10Exceptions:
11
12  SSLError -- exception raised for I/O errors
13
14Functions:
15
  cert_time_to_seconds -- convert a time string used in the certificate
                          notBefore and notAfter fields to integer
                          seconds past the Epoch (the time values
                          returned from time.time())
20
21  fetch_server_certificate (HOST, PORT) -- fetch the certificate provided
22                          by the server running on HOST at port PORT.  No
23                          validation of the certificate is performed.
24
25Integer constants:
26
27SSL_ERROR_ZERO_RETURN
28SSL_ERROR_WANT_READ
29SSL_ERROR_WANT_WRITE
30SSL_ERROR_WANT_X509_LOOKUP
31SSL_ERROR_SYSCALL
32SSL_ERROR_SSL
33SSL_ERROR_WANT_CONNECT
34
35SSL_ERROR_EOF
36SSL_ERROR_INVALID_ERROR_CODE
37
The following constants define the certificate requirements that one side
is allowing/requiring from the other side:
40
41CERT_NONE - no certificates from the other side are required (or will
42            be looked at if provided)
43CERT_OPTIONAL - certificates are not required, but if provided will be
44                validated, and if validation fails, the connection will
45                also fail
46CERT_REQUIRED - certificates are required, and will be validated, and
47                if validation fails, the connection will also fail
48
49The following constants identify various SSL protocol variants:
50
51PROTOCOL_SSLv2
52PROTOCOL_SSLv3
53PROTOCOL_SSLv23
54PROTOCOL_TLS
55PROTOCOL_TLSv1
56PROTOCOL_TLSv1_1
57PROTOCOL_TLSv1_2
58
59The following constants identify various SSL alert message descriptions as per
60http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6
61
62ALERT_DESCRIPTION_CLOSE_NOTIFY
63ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
64ALERT_DESCRIPTION_BAD_RECORD_MAC
65ALERT_DESCRIPTION_RECORD_OVERFLOW
66ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
67ALERT_DESCRIPTION_HANDSHAKE_FAILURE
68ALERT_DESCRIPTION_BAD_CERTIFICATE
69ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
70ALERT_DESCRIPTION_CERTIFICATE_REVOKED
71ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
72ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
73ALERT_DESCRIPTION_ILLEGAL_PARAMETER
74ALERT_DESCRIPTION_UNKNOWN_CA
75ALERT_DESCRIPTION_ACCESS_DENIED
76ALERT_DESCRIPTION_DECODE_ERROR
77ALERT_DESCRIPTION_DECRYPT_ERROR
78ALERT_DESCRIPTION_PROTOCOL_VERSION
79ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
80ALERT_DESCRIPTION_INTERNAL_ERROR
81ALERT_DESCRIPTION_USER_CANCELLED
82ALERT_DESCRIPTION_NO_RENEGOTIATION
83ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
84ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
85ALERT_DESCRIPTION_UNRECOGNIZED_NAME
86ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
87ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
88ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
89"""
90
91import textwrap
92import re
93import sys
94import os
95from collections import namedtuple
96from contextlib import closing
97
98import _ssl             # if we can't import it, let the error propagate
99
100from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
101from _ssl import _SSLContext
102from _ssl import (
103    SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
104    SSLSyscallError, SSLEOFError,
105    )
106from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
107from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
108from _ssl import RAND_status, RAND_add
109try:
110    from _ssl import RAND_egd
111except ImportError:
112    # LibreSSL does not provide RAND_egd
113    pass
114
def _import_symbols(prefix):
    """Re-export every ``_ssl`` attribute whose name begins with *prefix*.

    Each matching attribute is bound, under the same name, as a global of
    this module.
    """
    module_namespace = globals()
    module_namespace.update((attr, getattr(_ssl, attr))
                            for attr in dir(_ssl)
                            if attr.startswith(prefix))


# Pull the constant families of the C extension into this namespace.
for _prefix in ('OP_', 'ALERT_DESCRIPTION_', 'SSL_ERROR_', 'PROTOCOL_',
                'VERIFY_'):
    _import_symbols(_prefix)
del _prefix
125
126from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN
127
128from _ssl import _OPENSSL_API_VERSION
129
130_PROTOCOL_NAMES = {value: name for name, value in globals().items()
131                   if name.startswith('PROTOCOL_')
132                       and name != 'PROTOCOL_SSLv23'}
133PROTOCOL_SSLv23 = PROTOCOL_TLS
134
135try:
136    _SSLv2_IF_EXISTS = PROTOCOL_SSLv2
137except NameError:
138    _SSLv2_IF_EXISTS = None
139
140from socket import socket, _fileobject, _delegate_methods, error as socket_error
141if sys.platform == "win32":
142    from _ssl import enum_certificates, enum_crls
143
144from socket import socket, AF_INET, SOCK_STREAM, create_connection
145from socket import SOL_SOCKET, SO_TYPE
146import base64        # for DER-to-PEM translation
147import errno
148import warnings
149
# 'tls-unique' is the only channel binding type offered, and only when the
# _ssl extension reports support for it.
CHANNEL_BINDING_TYPES = ['tls-unique'] if _ssl.HAS_TLS_UNIQUE else []
154
155
# Disable weak or insecure ciphers by default
# (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL')
# Enable a better set of ciphers by default
# This list has been explicitly chosen to:
#   * Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE)
#   * Prefer ECDHE over DHE for better performance
#   * Prefer AEAD over CBC for better performance and security
#   * Prefer AES-GCM over ChaCha20 because most platforms have AES-NI
#     (ChaCha20 needs OpenSSL 1.1.0 or patched 1.0.2)
#   * Prefer any AES-GCM and ChaCha20 over any AES-CBC for better
#     performance and security
#   * Then use HIGH cipher suites as a fallback
#   * Disable NULL authentication, NULL encryption, 3DES and MD5 MACs
#     for security reasons
# SSLContext.__new__ applies this string to every new context whose
# protocol is not SSLv2.
_DEFAULT_CIPHERS = (
    'ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:'
    'ECDH+AES128:DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:'
    '!aNULL:!eNULL:!MD5:!3DES'
    )

# Restricted and more secure ciphers for the server side
# This list has been explicitly chosen to:
#   * Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE)
#   * Prefer ECDHE over DHE for better performance
#   * Prefer AEAD over CBC for better performance and security
#   * Prefer AES-GCM over ChaCha20 because most platforms have AES-NI
#   * Prefer any AES-GCM and ChaCha20 over any AES-CBC for better
#     performance and security
#   * Then use HIGH cipher suites as a fallback
#   * Disable NULL authentication, NULL encryption, MD5 MACs, DSS, RC4, and
#     3DES for security reasons
# create_default_context() installs this string for Purpose.CLIENT_AUTH
# (i.e. server-side) contexts.
_RESTRICTED_SERVER_CIPHERS = (
    'ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:'
    'ECDH+AES128:DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:'
    '!aNULL:!eNULL:!MD5:!DSS:!RC4:!3DES'
)
192
193
class CertificateError(ValueError):
    """Raised when a certificate's DNS names do not match the requested
    hostname, or when a certificate DNS name is malformed (for example it
    carries too many wildcards)."""
196
197
198def _dnsname_match(dn, hostname, max_wildcards=1):
199    """Matching according to RFC 6125, section 6.4.3
200
201    http://tools.ietf.org/html/rfc6125#section-6.4.3
202    """
203    pats = []
204    if not dn:
205        return False
206
207    pieces = dn.split(r'.')
208    leftmost = pieces[0]
209    remainder = pieces[1:]
210
211    wildcards = leftmost.count('*')
212    if wildcards > max_wildcards:
213        # Issue #17980: avoid denials of service by refusing more
214        # than one wildcard per fragment.  A survery of established
215        # policy among SSL implementations showed it to be a
216        # reasonable choice.
217        raise CertificateError(
218            "too many wildcards in certificate DNS name: " + repr(dn))
219
220    # speed up common case w/o wildcards
221    if not wildcards:
222        return dn.lower() == hostname.lower()
223
224    # RFC 6125, section 6.4.3, subitem 1.
225    # The client SHOULD NOT attempt to match a presented identifier in which
226    # the wildcard character comprises a label other than the left-most label.
227    if leftmost == '*':
228        # When '*' is a fragment by itself, it matches a non-empty dotless
229        # fragment.
230        pats.append('[^.]+')
231    elif leftmost.startswith('xn--') or hostname.startswith('xn--'):
232        # RFC 6125, section 6.4.3, subitem 3.
233        # The client SHOULD NOT attempt to match a presented identifier
234        # where the wildcard character is embedded within an A-label or
235        # U-label of an internationalized domain name.
236        pats.append(re.escape(leftmost))
237    else:
238        # Otherwise, '*' matches any dotless string, e.g. www*
239        pats.append(re.escape(leftmost).replace(r'\*', '[^.]*'))
240
241    # add the remaining fragments, ignore any wildcards
242    for frag in remainder:
243        pats.append(re.escape(frag))
244
245    pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
246    return pat.match(hostname)
247
248
def match_hostname(cert, hostname):
    """Check that *cert* is valid for *hostname*.

    *cert* is a decoded certificate as returned by
    SSLSocket.getpeercert().  RFC 2818 and RFC 6125 rules are followed,
    but IP addresses are not accepted for *hostname*.  The subject's
    commonName entries are consulted only when subjectAltName holds no
    DNS entry at all.

    Returns None on success.  Raises ValueError if *cert* is empty or
    missing, and CertificateError if no certificate name matches.
    """
    if not cert:
        raise ValueError("empty or no certificate, match_hostname needs a "
                         "SSL socket or SSL context with either "
                         "CERT_OPTIONAL or CERT_REQUIRED")

    candidates = []
    for kind, name in cert.get('subjectAltName', ()):
        if kind != 'DNS':
            continue
        if _dnsname_match(name, hostname):
            return
        candidates.append(name)

    if not candidates:
        # Fall back to the subject only when there is no dNSName entry.
        # XXX according to RFC 2818, the most specific Common Name
        # must be used.
        for rdn in cert.get('subject', ()):
            for kind, name in rdn:
                if kind != 'commonName':
                    continue
                if _dnsname_match(name, hostname):
                    return
                candidates.append(name)

    count = len(candidates)
    if count == 0:
        raise CertificateError("no appropriate commonName or "
            "subjectAltName fields were found")
    if count == 1:
        raise CertificateError("hostname %r "
            "doesn't match %r"
            % (hostname, candidates[0]))
    raise CertificateError("hostname %r "
        "doesn't match either of %s"
        % (hostname, ', '.join(map(repr, candidates))))
290
291
# Result record of get_default_verify_paths(): the effective cafile/capath
# (None when missing on disk) followed by the raw values reported by _ssl.
DefaultVerifyPaths = namedtuple("DefaultVerifyPaths", [
    "cafile", "capath",
    "openssl_cafile_env", "openssl_cafile",
    "openssl_capath_env", "openssl_capath",
])
295
def get_default_verify_paths():
    """Return paths to default cafile and capath.

    The four values reported by _ssl are (cafile env var name, cafile,
    capath env var name, capath).  A set environment variable takes
    precedence over the compiled-in path.  The effective cafile/capath are
    reported as None unless they exist as a file/directory respectively.
    """
    parts = _ssl.get_default_verify_paths()
    env = os.environ

    cafile = env.get(parts[0], parts[1])
    capath = env.get(parts[2], parts[3])
    if not os.path.isfile(cafile):
        cafile = None
    if not os.path.isdir(capath):
        capath = None

    return DefaultVerifyPaths(cafile, capath, *parts)
308
309
310class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")):
311    """ASN.1 object identifier lookup
312    """
313    __slots__ = ()
314
315    def __new__(cls, oid):
316        return super(_ASN1Object, cls).__new__(cls, *_txt2obj(oid, name=False))
317
318    @classmethod
319    def fromnid(cls, nid):
320        """Create _ASN1Object from OpenSSL numeric ID
321        """
322        return super(_ASN1Object, cls).__new__(cls, *_nid2obj(nid))
323
324    @classmethod
325    def fromname(cls, name):
326        """Create _ASN1Object from short name, long name or OID
327        """
328        return super(_ASN1Object, cls).__new__(cls, *_txt2obj(name, name=True))
329
330
class Purpose(_ASN1Object):
    """SSLContext purpose flags with X509v3 Extended Key Usage objects

    Instances are built from a dotted OID string; the two purposes used by
    this module are attached as class attributes right below.
    """

Purpose.SERVER_AUTH, Purpose.CLIENT_AUTH = (
    Purpose('1.3.6.1.5.5.7.3.1'),  # server authentication EKU
    Purpose('1.3.6.1.5.5.7.3.2'),  # client authentication EKU
)
337
338
class SSLContext(_SSLContext):
    """An SSLContext holds various SSL-related configuration options and
    data, such as certificates and possibly a private key."""

    __slots__ = ('protocol', '__weakref__')
    # Windows system certificate stores consulted by load_default_certs().
    _windows_cert_stores = ("CA", "ROOT")

    def __new__(cls, protocol, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but ignored.
        self = _SSLContext.__new__(cls, protocol)
        # Install the module's default cipher list for every protocol
        # except SSLv2 (_SSLv2_IF_EXISTS is None when SSLv2 is unavailable).
        if protocol != _SSLv2_IF_EXISTS:
            self.set_ciphers(_DEFAULT_CIPHERS)
        return self

    def __init__(self, protocol):
        self.protocol = protocol

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None):
        """Wrap *sock* in an SSLSocket that uses this context."""
        return SSLSocket(sock=sock, server_side=server_side,
                         do_handshake_on_connect=do_handshake_on_connect,
                         suppress_ragged_eofs=suppress_ragged_eofs,
                         server_hostname=server_hostname,
                         _context=self)

    @staticmethod
    def _encode_protocol_list(protocols, kind):
        """Serialize *protocols* into the length-prefixed form shared by NPN
        and ALPN: one length byte followed by the ASCII-encoded name, for
        each entry in order.

        Raises SSLError (naming *kind*) when an encoded name is empty or
        longer than 255 bytes.
        """
        protos = bytearray()
        for protocol in protocols:
            b = protocol.encode('ascii')
            if len(b) == 0 or len(b) > 255:
                raise SSLError('%s protocols must be 1 to 255 in length' % kind)
            protos.append(len(b))
            protos.extend(b)
        return protos

    def set_npn_protocols(self, npn_protocols):
        """Advertise *npn_protocols* (a sequence of names) via NPN."""
        self._set_npn_protocols(
            self._encode_protocol_list(npn_protocols, 'NPN'))

    def set_alpn_protocols(self, alpn_protocols):
        """Advertise *alpn_protocols* (a sequence of names) via ALPN."""
        self._set_alpn_protocols(
            self._encode_protocol_list(alpn_protocols, 'ALPN'))

    def _load_windows_store_certs(self, storename, purpose):
        """Load the certificates of Windows store *storename* that are
        trusted for *purpose* and return their concatenated bytes (possibly
        empty).  Failure to enumerate the store only emits a warning."""
        certs = bytearray()
        try:
            for cert, encoding, trust in enum_certificates(storename):
                # CA certs are never PKCS#7 encoded
                if encoding == "x509_asn":
                    # trust is either True (any purpose) or a collection
                    # of purpose OIDs.
                    if trust is True or purpose.oid in trust:
                        certs.extend(cert)
        except OSError:
            warnings.warn("unable to enumerate Windows certificate store")
        if certs:
            self.load_verify_locations(cadata=certs)
        return certs

    def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
        """Load default CA certificates for *purpose*.

        On Windows the system stores are consulted first; OpenSSL's default
        verify paths are always configured.  Raises TypeError if *purpose*
        is not an _ASN1Object.
        """
        if not isinstance(purpose, _ASN1Object):
            raise TypeError(purpose)
        if sys.platform == "win32":
            for storename in self._windows_cert_stores:
                self._load_windows_store_certs(storename, purpose)
        self.set_default_verify_paths()
408
409
def create_default_context(purpose=Purpose.SERVER_AUTH, cafile=None,
                           capath=None, cadata=None):
    """Create a SSLContext object with default settings.

    NOTE: The protocol and settings may change anytime without prior
          deprecation. The values represent a fair balance between maximum
          compatibility and security.

    For Purpose.SERVER_AUTH (client side) certificates and the host name
    are verified; for Purpose.CLIENT_AUTH (server side) server cipher
    preference, single-use DH/ECDH keys and a restricted cipher list are
    enabled.  CA certificates come from *cafile*/*capath*/*cadata* when any
    is given, otherwise from the system defaults whenever verification is
    enabled.  Raises TypeError if *purpose* is not an _ASN1Object.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    context = SSLContext(PROTOCOL_TLS)

    # Refuse SSLv2 (considered harmful) and SSLv3 (problematic security,
    # only needed for really old clients such as IE6 on Windows XP).
    context.options |= OP_NO_SSLv2
    context.options |= OP_NO_SSLv3

    # Disable compression to prevent CRIME attacks (OpenSSL 1.0+); the
    # flag may be absent from _ssl, hence the 0 fallback.
    context.options |= getattr(_ssl, "OP_NO_COMPRESSION", 0)

    if purpose == Purpose.SERVER_AUTH:
        context.verify_mode = CERT_REQUIRED
        context.check_hostname = True
    elif purpose == Purpose.CLIENT_AUTH:
        # Prefer the server's cipher order, and use single-use DH/ECDH keys
        # for better forward secrecy; each flag is optional in _ssl.
        for flag_name in ("OP_CIPHER_SERVER_PREFERENCE",
                          "OP_SINGLE_DH_USE",
                          "OP_SINGLE_ECDH_USE"):
            context.options |= getattr(_ssl, flag_name, 0)
        # disallow ciphers with known vulnerabilities
        context.set_ciphers(_RESTRICTED_SERVER_CIPHERS)

    if cafile or capath or cadata:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # Verification is on but no CA source was given: try the system
        # root CA certificates for this purpose (may fail silently).
        context.load_default_certs(purpose)
    return context
457
def _create_unverified_context(protocol=PROTOCOL_TLS, cert_reqs=None,
                           check_hostname=False, purpose=Purpose.SERVER_AUTH,
                           certfile=None, keyfile=None,
                           cafile=None, capath=None, cadata=None):
    """Create a SSLContext object for Python stdlib modules

    All Python stdlib modules shall use this function to create SSLContext
    objects in order to keep common settings in one place. The configuration
    is less restrictive than create_default_context()'s to increase backward
    compatibility.

    *cert_reqs*, when not None, becomes the verify mode; *check_hostname*
    is applied as given.  Raises TypeError for a non-_ASN1Object *purpose*
    and ValueError when *keyfile* is given without *certfile*.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    context = SSLContext(protocol)
    # Refuse SSLv2 (considered harmful) and SSLv3 (problematic security,
    # only needed for really old clients such as IE6 on Windows XP).
    for disabled in (OP_NO_SSLv2, OP_NO_SSLv3):
        context.options |= disabled

    if cert_reqs is not None:
        context.verify_mode = cert_reqs
    context.check_hostname = check_hostname

    if keyfile and not certfile:
        raise ValueError("certfile must be specified")
    if certfile or keyfile:
        context.load_cert_chain(certfile, keyfile)

    # CA roots: explicit sources win; otherwise fall back to the system
    # defaults whenever verification is enabled (this may fail silently).
    if cafile or capath or cadata:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        context.load_default_certs(purpose)

    return context
498
# Backwards compatibility alias, even though it's not a public name.
_create_stdlib_context = _create_unverified_context

# PEP 493: Verify HTTPS by default, but allow envvar to override that
_https_verify_envvar = 'PYTHONHTTPSVERIFY'

def _get_https_context_factory():
    """Pick the default HTTPS context factory (PEP 493).

    Returns _create_unverified_context only when environment variables are
    honoured (sys.flags.ignore_environment is false) and PYTHONHTTPSVERIFY
    is exactly '0'; otherwise returns create_default_context.
    """
    if sys.flags.ignore_environment:
        return create_default_context
    if os.environ.get(_https_verify_envvar) == '0':
        return _create_unverified_context
    return create_default_context

_create_default_https_context = _get_https_context_factory()
513
# PEP 493: "private" API to configure HTTPS defaults without monkeypatching
def _https_verify_certificates(enable=True):
    """Verify server HTTPS certificates by default?

    A true *enable* makes create_default_context the module-wide HTTPS
    context factory; a false one switches it to _create_unverified_context.
    """
    global _create_default_https_context
    _create_default_https_context = (
        create_default_context if enable else _create_unverified_context)
522
523
524class SSLSocket(socket):
525    """This class implements a subtype of socket.socket that wraps
526    the underlying OS socket in an SSL context when necessary, and
527    provides read and write methods over that channel."""
528
529    def __init__(self, sock=None, keyfile=None, certfile=None,
530                 server_side=False, cert_reqs=CERT_NONE,
531                 ssl_version=PROTOCOL_TLS, ca_certs=None,
532                 do_handshake_on_connect=True,
533                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
534                 suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
535                 server_hostname=None,
536                 _context=None):
537
538        self._makefile_refs = 0
539        if _context:
540            self._context = _context
541        else:
542            if server_side and not certfile:
543                raise ValueError("certfile must be specified for server-side "
544                                 "operations")
545            if keyfile and not certfile:
546                raise ValueError("certfile must be specified")
547            if certfile and not keyfile:
548                keyfile = certfile
549            self._context = SSLContext(ssl_version)
550            self._context.verify_mode = cert_reqs
551            if ca_certs:
552                self._context.load_verify_locations(ca_certs)
553            if certfile:
554                self._context.load_cert_chain(certfile, keyfile)
555            if npn_protocols:
556                self._context.set_npn_protocols(npn_protocols)
557            if ciphers:
558                self._context.set_ciphers(ciphers)
559            self.keyfile = keyfile
560            self.certfile = certfile
561            self.cert_reqs = cert_reqs
562            self.ssl_version = ssl_version
563            self.ca_certs = ca_certs
564            self.ciphers = ciphers
565        # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
566        # mixed in.
567        if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
568            raise NotImplementedError("only stream sockets are supported")
569        socket.__init__(self, _sock=sock._sock)
570        # The initializer for socket overrides the methods send(), recv(), etc.
571        # in the instancce, which we don't need -- but we want to provide the
572        # methods defined in SSLSocket.
573        for attr in _delegate_methods:
574            try:
575                delattr(self, attr)
576            except AttributeError:
577                pass
578        if server_side and server_hostname:
579            raise ValueError("server_hostname can only be specified "
580                             "in client mode")
581        if self._context.check_hostname and not server_hostname:
582            raise ValueError("check_hostname requires server_hostname")
583        self.server_side = server_side
584        self.server_hostname = server_hostname
585        self.do_handshake_on_connect = do_handshake_on_connect
586        self.suppress_ragged_eofs = suppress_ragged_eofs
587
588        # See if we are connected
589        try:
590            self.getpeername()
591        except socket_error as e:
592            if e.errno != errno.ENOTCONN:
593                raise
594            connected = False
595        else:
596            connected = True
597
598        self._closed = False
599        self._sslobj = None
600        self._connected = connected
601        if connected:
602            # create the SSL object
603            try:
604                self._sslobj = self._context._wrap_socket(self._sock, server_side,
605                                                          server_hostname, ssl_sock=self)
606                if do_handshake_on_connect:
607                    timeout = self.gettimeout()
608                    if timeout == 0.0:
609                        # non-blocking
610                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
611                    self.do_handshake()
612
613            except (OSError, ValueError):
614                self.close()
615                raise
616
    @property
    def context(self):
        """The SSLContext this socket uses (set at creation or reassigned)."""
        return self._context

    @context.setter
    def context(self, ctx):
        # Rebind both the Python-level reference and the SSL object's
        # context.  NOTE(review): when self._sslobj is None this raises
        # AttributeError after self._context has already been updated.
        self._context = ctx
        self._sslobj.context = ctx
625
626    def dup(self):
627        raise NotImplemented("Can't dup() %s instances" %
628                             self.__class__.__name__)
629
630    def _checkClosed(self, msg=None):
631        # raise an exception here if you wish to check for spurious closes
632        pass
633
634    def _check_connected(self):
635        if not self._connected:
636            # getpeername() will raise ENOTCONN if the socket is really
637            # not connected; note that we can be connected even without
638            # _connected being set, e.g. if connect() first returned
639            # EAGAIN.
640            self.getpeername()
641
642    def read(self, len=1024, buffer=None):
643        """Read up to LEN bytes and return them.
644        Return zero-length string on EOF."""
645
646        self._checkClosed()
647        if not self._sslobj:
648            raise ValueError("Read on closed or unwrapped SSL socket.")
649        try:
650            if buffer is not None:
651                v = self._sslobj.read(len, buffer)
652            else:
653                v = self._sslobj.read(len)
654            return v
655        except SSLError as x:
656            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
657                if buffer is not None:
658                    return 0
659                else:
660                    return b''
661            else:
662                raise
663
664    def write(self, data):
665        """Write DATA to the underlying SSL channel.  Returns
666        number of bytes of DATA actually transmitted."""
667
668        self._checkClosed()
669        if not self._sslobj:
670            raise ValueError("Write on closed or unwrapped SSL socket.")
671        return self._sslobj.write(data)
672
673    def getpeercert(self, binary_form=False):
674        """Returns a formatted version of the data in the
675        certificate provided by the other end of the SSL channel.
676        Return None if no certificate was provided, {} if a
677        certificate was provided, but not validated."""
678
679        self._checkClosed()
680        self._check_connected()
681        return self._sslobj.peer_certificate(binary_form)
682
683    def selected_npn_protocol(self):
684        self._checkClosed()
685        if not self._sslobj or not _ssl.HAS_NPN:
686            return None
687        else:
688            return self._sslobj.selected_npn_protocol()
689
690    def selected_alpn_protocol(self):
691        self._checkClosed()
692        if not self._sslobj or not _ssl.HAS_ALPN:
693            return None
694        else:
695            return self._sslobj.selected_alpn_protocol()
696
697    def cipher(self):
698        self._checkClosed()
699        if not self._sslobj:
700            return None
701        else:
702            return self._sslobj.cipher()
703
704    def compression(self):
705        self._checkClosed()
706        if not self._sslobj:
707            return None
708        else:
709            return self._sslobj.compression()
710
711    def send(self, data, flags=0):
712        self._checkClosed()
713        if self._sslobj:
714            if flags != 0:
715                raise ValueError(
716                    "non-zero flags not allowed in calls to send() on %s" %
717                    self.__class__)
718            try:
719                v = self._sslobj.write(data)
720            except SSLError as x:
721                if x.args[0] == SSL_ERROR_WANT_READ:
722                    return 0
723                elif x.args[0] == SSL_ERROR_WANT_WRITE:
724                    return 0
725                else:
726                    raise
727            else:
728                return v
729        else:
730            return self._sock.send(data, flags)
731
732    def sendto(self, data, flags_or_addr, addr=None):
733        self._checkClosed()
734        if self._sslobj:
735            raise ValueError("sendto not allowed on instances of %s" %
736                             self.__class__)
737        elif addr is None:
738            return self._sock.sendto(data, flags_or_addr)
739        else:
740            return self._sock.sendto(data, flags_or_addr, addr)
741
742
743    def sendall(self, data, flags=0):
744        self._checkClosed()
745        if self._sslobj:
746            if flags != 0:
747                raise ValueError(
748                    "non-zero flags not allowed in calls to sendall() on %s" %
749                    self.__class__)
750            amount = len(data)
751            count = 0
752            while (count < amount):
753                v = self.send(data[count:])
754                count += v
755            return amount
756        else:
757            return socket.sendall(self, data, flags)
758
759    def recv(self, buflen=1024, flags=0):
760        self._checkClosed()
761        if self._sslobj:
762            if flags != 0:
763                raise ValueError(
764                    "non-zero flags not allowed in calls to recv() on %s" %
765                    self.__class__)
766            return self.read(buflen)
767        else:
768            return self._sock.recv(buflen, flags)
769
770    def recv_into(self, buffer, nbytes=None, flags=0):
771        self._checkClosed()
772        if buffer and (nbytes is None):
773            nbytes = len(buffer)
774        elif nbytes is None:
775            nbytes = 1024
776        if self._sslobj:
777            if flags != 0:
778                raise ValueError(
779                  "non-zero flags not allowed in calls to recv_into() on %s" %
780                  self.__class__)
781            return self.read(nbytes, buffer)
782        else:
783            return self._sock.recv_into(buffer, nbytes, flags)
784
785    def recvfrom(self, buflen=1024, flags=0):
786        self._checkClosed()
787        if self._sslobj:
788            raise ValueError("recvfrom not allowed on instances of %s" %
789                             self.__class__)
790        else:
791            return self._sock.recvfrom(buflen, flags)
792
793    def recvfrom_into(self, buffer, nbytes=None, flags=0):
794        self._checkClosed()
795        if self._sslobj:
796            raise ValueError("recvfrom_into not allowed on instances of %s" %
797                             self.__class__)
798        else:
799            return self._sock.recvfrom_into(buffer, nbytes, flags)
800
801
802    def pending(self):
803        self._checkClosed()
804        if self._sslobj:
805            return self._sslobj.pending()
806        else:
807            return 0
808
809    def shutdown(self, how):
810        self._checkClosed()
811        self._sslobj = None
812        socket.shutdown(self, how)
813
814    def close(self):
815        if self._makefile_refs < 1:
816            self._sslobj = None
817            socket.close(self)
818        else:
819            self._makefile_refs -= 1
820
821    def unwrap(self):
822        if self._sslobj:
823            s = self._sslobj.shutdown()
824            self._sslobj = None
825            return s
826        else:
827            raise ValueError("No SSL wrapper around " + str(self))
828
829    def _real_close(self):
830        self._sslobj = None
831        socket._real_close(self)
832
833    def do_handshake(self, block=False):
834        """Perform a TLS/SSL handshake."""
835        self._check_connected()
836        timeout = self.gettimeout()
837        try:
838            if timeout == 0.0 and block:
839                self.settimeout(None)
840            self._sslobj.do_handshake()
841        finally:
842            self.settimeout(timeout)
843
844        if self.context.check_hostname:
845            if not self.server_hostname:
846                raise ValueError("check_hostname needs server_hostname "
847                                 "argument")
848            match_hostname(self.getpeercert(), self.server_hostname)
849
850    def _real_connect(self, addr, connect_ex):
851        if self.server_side:
852            raise ValueError("can't connect in server-side mode")
853        # Here we assume that the socket is client-side, and not
854        # connected at the time of the call.  We connect it, then wrap it.
855        if self._connected:
856            raise ValueError("attempt to connect already-connected SSLSocket!")
857        self._sslobj = self.context._wrap_socket(self._sock, False, self.server_hostname, ssl_sock=self)
858        try:
859            if connect_ex:
860                rc = socket.connect_ex(self, addr)
861            else:
862                rc = None
863                socket.connect(self, addr)
864            if not rc:
865                self._connected = True
866                if self.do_handshake_on_connect:
867                    self.do_handshake()
868            return rc
869        except (OSError, ValueError):
870            self._sslobj = None
871            raise
872
873    def connect(self, addr):
874        """Connects to remote ADDR, and then wraps the connection in
875        an SSL channel."""
876        self._real_connect(addr, False)
877
878    def connect_ex(self, addr):
879        """Connects to remote ADDR, and then wraps the connection in
880        an SSL channel."""
881        return self._real_connect(addr, True)
882
883    def accept(self):
884        """Accepts a new connection from a remote client, and returns
885        a tuple containing that new connection wrapped with a server-side
886        SSL channel, and the address of the remote client."""
887
888        newsock, addr = socket.accept(self)
889        newsock = self.context.wrap_socket(newsock,
890                    do_handshake_on_connect=self.do_handshake_on_connect,
891                    suppress_ragged_eofs=self.suppress_ragged_eofs,
892                    server_side=True)
893        return newsock, addr
894
895    def makefile(self, mode='r', bufsize=-1):
896
897        """Make and return a file-like object that
898        works with the SSL connection.  Just use the code
899        from the socket module."""
900
901        self._makefile_refs += 1
902        # close=True so as to decrement the reference count when done with
903        # the file-like object.
904        return _fileobject(self, mode, bufsize, close=True)
905
906    def get_channel_binding(self, cb_type="tls-unique"):
907        """Get channel binding data for current connection.  Raise ValueError
908        if the requested `cb_type` is not supported.  Return bytes of the data
909        or None if the data is not available (e.g. before the handshake).
910        """
911        if cb_type not in CHANNEL_BINDING_TYPES:
912            raise ValueError("Unsupported channel binding type")
913        if cb_type != "tls-unique":
914            raise NotImplementedError(
915                            "{0} channel binding type not implemented"
916                            .format(cb_type))
917        if self._sslobj is None:
918            return None
919        return self._sslobj.tls_unique_cb()
920
921    def version(self):
922        """
923        Return a string identifying the protocol version used by the
924        current SSL channel, or None if there is no established channel.
925        """
926        if self._sslobj is None:
927            return None
928        return self._sslobj.version()
929
930
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_TLS, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """Wrap *sock* in a new SSLSocket configured with the given options.

    A thin convenience function: every argument is forwarded by keyword to
    the SSLSocket constructor, and the resulting SSLSocket is returned.
    """
    options = {
        'keyfile': keyfile,
        'certfile': certfile,
        'server_side': server_side,
        'cert_reqs': cert_reqs,
        'ssl_version': ssl_version,
        'ca_certs': ca_certs,
        'do_handshake_on_connect': do_handshake_on_connect,
        'suppress_ragged_eofs': suppress_ragged_eofs,
        'ciphers': ciphers,
    }
    return SSLSocket(sock=sock, **options)
944
945# some utility functions
946
def cert_time_to_seconds(cert_time):
    """Convert a certificate "notBefore"/"notAfter" string to epoch seconds.

    *cert_time* is expected in ``"%b %d %H:%M:%S %Y %Z"`` strptime format
    (C locale), with the zone written as GMT (see ASN1_TIME_print()).  RFC
    5280 requires these dates to be UTC.  The month abbreviation is one of
    Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec and is matched
    case-insensitively.

    Returns an integer.  Raises ValueError if the month abbreviation or the
    rest of the string does not match the format.
    """
    from time import strptime
    from calendar import timegm

    month_names = ("Jan", "Feb", "Mar", "Apr", "May", "Jun",
                   "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")
    # The month is looked up by hand rather than via %b, presumably so that
    # parsing does not depend on the current locale; zone is fixed to GMT.
    rest_format = ' %d %H:%M:%S %Y GMT'
    try:
        month = 1 + month_names.index(cert_time[:3].title())
    except ValueError:
        raise ValueError('time data %r does not match '
                         'format "%%b%s"' % (cert_time, rest_format))
    parsed = strptime(cert_time[3:], rest_format)
    # Splice in the hand-parsed month; timegm treats the tuple as UTC and
    # returns an int (a mktime()-based version used to return a float).
    return timegm((parsed[0], month) + parsed[2:6])
976
# Armor lines delimiting a PEM-encoded certificate.
PEM_HEADER, PEM_FOOTER = ("-----BEGIN CERTIFICATE-----",
                          "-----END CERTIFICATE-----")
979
def DER_cert_to_PEM_cert(der_cert_bytes):
    """Convert a binary DER certificate into its PEM text form.

    The DER bytes are base64-encoded, split into 64-character lines, and
    framed by PEM_HEADER and PEM_FOOTER, with a trailing newline.
    """
    encoded = base64.standard_b64encode(der_cert_bytes).decode('ascii')
    # Base64 text has no whitespace or hyphens, so fixed-width slicing gives
    # the same 64-column lines a text wrapper would.
    body = '\n'.join(encoded[start:start + 64]
                     for start in range(0, len(encoded), 64))
    return PEM_HEADER + '\n' + body + '\n' + PEM_FOOTER + '\n'
988
def PEM_cert_to_DER_cert(pem_cert_string):
    """Convert a PEM certificate string back into DER-encoded bytes.

    The string must begin with PEM_HEADER (no leading whitespace allowed)
    and, after stripping surrounding whitespace, end with PEM_FOOTER.
    Otherwise a ValueError is raised.  The text between the two markers is
    base64-decoded.
    """
    if not pem_cert_string.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    trimmed = pem_cert_string.strip()
    if not trimmed.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    payload = trimmed[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodestring(payload.encode('ASCII', 'strict'))
1001
def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None):
    """Fetch the certificate of the server at *addr* as a PEM string.

    *addr* is a ``(host, port)`` pair.  When *ca_certs* is given, the
    server certificate is required and validated against it; otherwise no
    validation is done.  *ssl_version* selects the protocol for the
    connection attempt.
    """
    # Unpacking enforces that addr is a 2-item (host, port) pair.
    _host, _port = addr
    required = CERT_REQUIRED if ca_certs is not None else CERT_NONE
    context = _create_stdlib_context(ssl_version,
                                     cert_reqs=required,
                                     cafile=ca_certs)
    with closing(create_connection(addr)) as raw_sock:
        with closing(context.wrap_socket(raw_sock)) as tls_sock:
            der_bytes = tls_sock.getpeercert(True)
    return DER_cert_to_PEM_cert(der_bytes)
1020
def get_protocol_name(protocol_code):
    """Return the name registered for *protocol_code*, or '<unknown>'."""
    fallback = '<unknown>'
    return _PROTOCOL_NAMES.get(protocol_code, fallback)
1023
1024
1025# a replacement for the old socket.ssl function
1026
def sslwrap_simple(sock, keyfile=None, certfile=None):
    """Replacement for the old socket.ssl function.

    Designed for compatibility with Python 2.5 and earlier.  Wraps the raw
    socket (``sock._sock`` when present) in a client-side SSL object built
    from an SSLv23 context, loading *certfile*/*keyfile* if either is given.
    If the socket is already connected, the handshake is done immediately.
    """
    if hasattr(sock, "_sock"):
        sock = sock._sock

    context = SSLContext(PROTOCOL_SSLv23)
    if keyfile or certfile:
        context.load_cert_chain(certfile, keyfile)
    wrapped = context._wrap_socket(sock, server_side=False)
    try:
        sock.getpeername()
    except socket_error:
        # Not connected yet, so the handshake has to wait.
        return wrapped
    # Already connected: perform the handshake now.
    wrapped.do_handshake()
    return wrapped
1048