1# Wrapper module for _ssl, providing some additional facilities
2# implemented in Python.  Written by Bill Janssen.
3
4"""This module provides some more Pythonic support for SSL.
5
6Object types:
7
8  SSLSocket -- subtype of socket.socket which does SSL over the socket
9
10Exceptions:
11
12  SSLError -- exception raised for I/O errors
13
14Functions:
15
  cert_time_to_seconds -- convert time string used for certificate
                          notBefore and notAfter fields to integer
                          seconds past the Epoch (the time values
                          returned from time.time())
20
21  fetch_server_certificate (HOST, PORT) -- fetch the certificate provided
22                          by the server running on HOST at port PORT.  No
23                          validation of the certificate is performed.
24
25Integer constants:
26
27SSL_ERROR_ZERO_RETURN
28SSL_ERROR_WANT_READ
29SSL_ERROR_WANT_WRITE
30SSL_ERROR_WANT_X509_LOOKUP
31SSL_ERROR_SYSCALL
32SSL_ERROR_SSL
33SSL_ERROR_WANT_CONNECT
34
35SSL_ERROR_EOF
36SSL_ERROR_INVALID_ERROR_CODE
37
The following group defines the certificate requirements that one side
allows or requires from the other side:
40
41CERT_NONE - no certificates from the other side are required (or will
42            be looked at if provided)
43CERT_OPTIONAL - certificates are not required, but if provided will be
44                validated, and if validation fails, the connection will
45                also fail
46CERT_REQUIRED - certificates are required, and will be validated, and
47                if validation fails, the connection will also fail
48
49The following constants identify various SSL protocol variants:
50
51PROTOCOL_SSLv2
52PROTOCOL_SSLv3
53PROTOCOL_SSLv23
54PROTOCOL_TLS
55PROTOCOL_TLS_CLIENT
56PROTOCOL_TLS_SERVER
57PROTOCOL_TLSv1
58PROTOCOL_TLSv1_1
59PROTOCOL_TLSv1_2
60
61The following constants identify various SSL alert message descriptions as per
62http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6
63
64ALERT_DESCRIPTION_CLOSE_NOTIFY
65ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
66ALERT_DESCRIPTION_BAD_RECORD_MAC
67ALERT_DESCRIPTION_RECORD_OVERFLOW
68ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
69ALERT_DESCRIPTION_HANDSHAKE_FAILURE
70ALERT_DESCRIPTION_BAD_CERTIFICATE
71ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
72ALERT_DESCRIPTION_CERTIFICATE_REVOKED
73ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
74ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
75ALERT_DESCRIPTION_ILLEGAL_PARAMETER
76ALERT_DESCRIPTION_UNKNOWN_CA
77ALERT_DESCRIPTION_ACCESS_DENIED
78ALERT_DESCRIPTION_DECODE_ERROR
79ALERT_DESCRIPTION_DECRYPT_ERROR
80ALERT_DESCRIPTION_PROTOCOL_VERSION
81ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
82ALERT_DESCRIPTION_INTERNAL_ERROR
83ALERT_DESCRIPTION_USER_CANCELLED
84ALERT_DESCRIPTION_NO_RENEGOTIATION
85ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
86ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
87ALERT_DESCRIPTION_UNRECOGNIZED_NAME
88ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
89ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
90ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
91"""
92
93import ipaddress
94import textwrap
95import re
96import sys
97import os
98from collections import namedtuple
99from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag
100
101import _ssl             # if we can't import it, let the error propagate
102
103from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
104from _ssl import _SSLContext, MemoryBIO, SSLSession
105from _ssl import (
106    SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
107    SSLSyscallError, SSLEOFError,
108    )
109from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
110from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes
111try:
112    from _ssl import RAND_egd
113except ImportError:
114    # LibreSSL does not provide RAND_egd
115    pass
116
117
118from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN
119from _ssl import _OPENSSL_API_VERSION
120
121
# Build IntEnum/IntFlag classes from the integer constants exported by _ssl.
# Each _convert call gathers the _ssl names accepted by its predicate.  The
# rest of this module uses both the resulting class names (_SSLMethod,
# Options, VerifyFlags, VerifyMode, ...) and the member names (PROTOCOL_TLS,
# CERT_NONE, SSL_ERROR_EOF, ...) as plain module globals, so these calls are
# relied upon to install them into this module's namespace.
# NOTE(review): _convert is a private enum helper whose name/signature has
# changed across Python versions -- confirm against the target interpreter.
_IntEnum._convert(
    '_SSLMethod', __name__,
    # PROTOCOL_SSLv23 is excluded here and re-added below as an alias.
    lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
    source=_ssl)

# OP_* option bits (a flag type, so values can be OR-ed together).
_IntFlag._convert(
    'Options', __name__,
    lambda name: name.startswith('OP_'),
    source=_ssl)

# TLS alert description codes.
_IntEnum._convert(
    'AlertDescription', __name__,
    lambda name: name.startswith('ALERT_DESCRIPTION_'),
    source=_ssl)

# SSL_ERROR_* codes (e.g. SSLSocket.read compares against SSL_ERROR_EOF).
_IntEnum._convert(
    'SSLErrorNumber', __name__,
    lambda name: name.startswith('SSL_ERROR_'),
    source=_ssl)

# VERIFY_* certificate-verification flag bits.
_IntFlag._convert(
    'VerifyFlags', __name__,
    lambda name: name.startswith('VERIFY_'),
    source=_ssl)

# CERT_NONE / CERT_OPTIONAL / CERT_REQUIRED verification modes.
_IntEnum._convert(
    'VerifyMode', __name__,
    lambda name: name.startswith('CERT_'),
    source=_ssl)


# Keep PROTOCOL_SSLv23 as an alias of PROTOCOL_TLS, both as a module global
# and as an attribute of the _SSLMethod class.
PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS
# Reverse lookup: protocol value -> member name.
_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}

# The SSLv2 member when this _ssl build defines one, else None.  SSLContext
# compares the requested protocol against it before applying _DEFAULT_CIPHERS.
_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None)
157
158
159if sys.platform == "win32":
160    from _ssl import enum_certificates, enum_crls
161
162from socket import socket, AF_INET, SOCK_STREAM, create_connection
163from socket import SOL_SOCKET, SO_TYPE
164import base64        # for DER-to-PEM translation
165import errno
166import warnings
167
168
# Historical public alias for the socket error type; kept in the module
# namespace for backward compatibility.
socket_error = OSError

# Channel-binding types this build can provide.  'tls-unique' is listed only
# when the _ssl extension reports support for it.
CHANNEL_BINDING_TYPES = ['tls-unique'] if _ssl.HAS_TLS_UNIQUE else []
175
176
# Cipher configuration.  OpenSSL's own default setting is
# 'DEFAULT:!aNULL:!eNULL'; both lists below replace it with an explicitly
# ordered preference list.  That order has been chosen to:
#   * Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE)
#   * Prefer ECDHE over DHE for better performance
#   * Prefer AEAD over CBC for better performance and security
#   * Prefer AES-GCM over ChaCha20 because most platforms have AES-NI
#     (ChaCha20 needs OpenSSL 1.1.0 or patched 1.0.2)
#   * Prefer any AES-GCM and ChaCha20 over any AES-CBC for better
#     performance and security
#   * Then use HIGH cipher suites as a fallback
_CIPHER_PREFERENCE = ':'.join((
    'ECDH+AESGCM', 'ECDH+CHACHA20', 'DH+AESGCM', 'DH+CHACHA20',
    'ECDH+AES256', 'DH+AES256', 'ECDH+AES128', 'DH+AES',
    'ECDH+HIGH', 'DH+HIGH', 'RSA+AESGCM', 'RSA+AES', 'RSA+HIGH',
))

# Default (client-side) ciphers: the preference list with NULL
# authentication, NULL encryption, MD5 MACs and 3DES disabled for security
# reasons.
_DEFAULT_CIPHERS = ':'.join((_CIPHER_PREFERENCE,
                             '!aNULL', '!eNULL', '!MD5', '!3DES'))

# Restricted and more secure ciphers for the server side: additionally
# disables DSS and RC4.
_RESTRICTED_SERVER_CIPHERS = ':'.join((_CIPHER_PREFERENCE,
                                       '!aNULL', '!eNULL', '!MD5',
                                       '!DSS', '!RC4', '!3DES'))
213
214
class CertificateError(ValueError):
    """Raised by the host-name matching helpers when a certificate does not
    match the requested host, or when one of its DNS names carries an
    unacceptable wildcard pattern."""
217
218
def _dnsname_match(dn, hostname, max_wildcards=1):
    """Matching according to RFC 6125, section 6.4.3

    http://tools.ietf.org/html/rfc6125#section-6.4.3

    Return True if the certificate DNS name *dn* matches *hostname*, else
    False.  Comparison is case-insensitive.  Only the left-most label of
    *dn* may act as a wildcard; '*' characters in any other label are
    matched literally.

    Raises CertificateError if the left-most label of *dn* contains more
    than *max_wildcards* '*' characters.
    """
    if not dn:
        return False

    leftmost, *remainder = dn.split('.')

    wildcards = leftmost.count('*')
    if wildcards > max_wildcards:
        # Issue #17980: avoid denials of service by refusing more
        # than one wildcard per fragment.  A survey of established
        # policy among SSL implementations showed it to be a
        # reasonable choice.
        raise CertificateError(
            "too many wildcards in certificate DNS name: " + repr(dn))

    # speed up common case w/o wildcards
    if not wildcards:
        return dn.lower() == hostname.lower()

    pats = []
    # RFC 6125, section 6.4.3, subitem 1.
    # The client SHOULD NOT attempt to match a presented identifier in which
    # the wildcard character comprises a label other than the left-most label.
    if leftmost == '*':
        # When '*' is a fragment by itself, it matches a non-empty dotless
        # fragment.
        pats.append('[^.]+')
    elif (leftmost.lower().startswith('xn--')
          or hostname.lower().startswith('xn--')):
        # RFC 6125, section 6.4.3, subitem 3.
        # The client SHOULD NOT attempt to match a presented identifier
        # where the wildcard character is embedded within an A-label or
        # U-label of an internationalized domain name.  The ACE prefix is
        # case-insensitive, so 'XN--' must be caught as well.
        pats.append(re.escape(leftmost))
    else:
        # Otherwise, '*' matches any dotless string, e.g. www*
        pats.append(re.escape(leftmost).replace(r'\*', '[^.]*'))

    # add the remaining fragments, ignore any wildcards
    pats.extend(re.escape(frag) for frag in remainder)

    pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
    # Always return a bool (previously a Match object or None leaked out).
    return pat.match(hostname) is not None
266
267
def _ipaddress_match(ipname, host_ip):
    """Return True when the certificate IP string *ipname* names exactly the
    address object *host_ip*.

    RFC 6125 leaves IP address matching out of scope (section 1.7.2), so a
    plain equality test is used.
    """
    # Trailing whitespace is stripped first: OpenSSL may append a newline to
    # a subjectAltName's IP address.
    parsed = ipaddress.ip_address(ipname.rstrip())
    return parsed == host_ip
277
278
def match_hostname(cert, hostname):
    """Verify that *cert* (in decoded format as returned by
    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 and RFC 6125
    rules are followed.

    *hostname* may be a DNS name or an IP address string.  A DNS name is
    matched against the certificate's subjectAltName 'DNS' entries, and an
    IP address against its 'IP Address' entries.  Only when the certificate
    has neither kind of subjectAltName entry is the subject's commonName
    consulted.  For an IP address that fallback is an exact, case-insensitive
    textual comparison: wildcard patterns never match IP addresses.

    CertificateError is raised on failure. On success, the function
    returns nothing.  ValueError is raised if *cert* is empty or None.
    """
    if not cert:
        raise ValueError("empty or no certificate, match_hostname needs a "
                         "SSL socket or SSL context with either "
                         "CERT_OPTIONAL or CERT_REQUIRED")
    try:
        host_ip = ipaddress.ip_address(hostname)
    except ValueError:
        # Not an IP address (common case)
        host_ip = None
    dnsnames = []
    san = cert.get('subjectAltName', ())
    for key, value in san:
        if key == 'DNS':
            if host_ip is None and _dnsname_match(value, hostname):
                return
            dnsnames.append(value)
        elif key == 'IP Address':
            if host_ip is not None and _ipaddress_match(value, host_ip):
                return
            dnsnames.append(value)
    if not dnsnames:
        # The subject is only checked when there is no dNSName entry
        # in subjectAltName
        for sub in cert.get('subject', ()):
            for key, value in sub:
                # XXX according to RFC 2818, the most specific Common Name
                # must be used.
                if key == 'commonName':
                    if host_ip is None:
                        matched = _dnsname_match(value, hostname)
                    else:
                        # Never let a wildcard CN (e.g. '*.0.0.1') match an
                        # IP address; only an exact textual match counts.
                        matched = value.lower() == hostname.lower()
                    if matched:
                        return
                    dnsnames.append(value)
    if len(dnsnames) > 1:
        raise CertificateError("hostname %r "
            "doesn't match either of %s"
            % (hostname, ', '.join(map(repr, dnsnames))))
    elif len(dnsnames) == 1:
        raise CertificateError("hostname %r "
            "doesn't match %r"
            % (hostname, dnsnames[0]))
    else:
        raise CertificateError("no appropriate commonName or "
            "subjectAltName fields were found")
329
330
# Result type of get_default_verify_paths(): the two effective paths
# (cafile, capath) followed by the four raw values reported by OpenSSL.
DefaultVerifyPaths = namedtuple(
    "DefaultVerifyPaths",
    ["cafile", "capath",
     "openssl_cafile_env", "openssl_cafile",
     "openssl_capath_env", "openssl_capath"],
)
334
def get_default_verify_paths():
    """Return paths to default cafile and capath.

    The result is a DefaultVerifyPaths tuple.  Its first two entries are the
    effective cafile/capath: the environment variable named by OpenSSL wins
    over OpenSSL's compiled-in path, and an entry becomes None when it does
    not point to an existing file (cafile) or directory (capath).  The
    remaining four entries are the raw values from _ssl.
    """
    raw = _ssl.get_default_verify_paths()
    environ = os.environ

    # raw is (cafile env var, cafile, capath env var, capath); an environment
    # variable shadows the corresponding default path.
    resolved_cafile = environ.get(raw[0], raw[1])
    resolved_capath = environ.get(raw[2], raw[3])

    usable_cafile = resolved_cafile if os.path.isfile(resolved_cafile) else None
    usable_capath = resolved_capath if os.path.isdir(resolved_capath) else None
    return DefaultVerifyPaths(usable_cafile, usable_capath, *raw)
347
348
class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")):
    """ASN.1 object identifier lookup

    An immutable (nid, shortname, longname, oid) record resolved through
    OpenSSL.  The plain constructor takes a dotted OID string; fromnid() and
    fromname() are alternate constructors.
    """
    __slots__ = ()

    def __new__(cls, oid):
        # name=False: only a dotted numeric OID is accepted here.
        record = _txt2obj(oid, name=False)
        return super().__new__(cls, *record)

    @classmethod
    def fromnid(cls, nid):
        """Build an _ASN1Object from an OpenSSL numeric ID."""
        record = _nid2obj(nid)
        return super().__new__(cls, *record)

    @classmethod
    def fromname(cls, name):
        """Build an _ASN1Object from a short name, long name or OID."""
        record = _txt2obj(name, name=True)
        return super().__new__(cls, *record)
368
369
class Purpose(_ASN1Object, _Enum):
    """SSLContext purpose flags with X509v3 Extended Key Usage objects

    Each member is declared with a dotted OID string; because _ASN1Object is
    mixed in, that string is resolved into a (nid, shortname, longname, oid)
    record.  Members are passed as the *purpose* argument of
    create_default_context() and SSLContext.load_default_certs().
    """
    # Authenticate servers: create_default_context() uses this for
    # client-mode contexts (CERT_REQUIRED plus host-name checking).
    SERVER_AUTH = '1.3.6.1.5.5.7.3.1'
    # Authenticate clients: create_default_context() uses this for
    # server-side contexts (restricted server cipher list).
    CLIENT_AUTH = '1.3.6.1.5.5.7.3.2'
375
376
class SSLContext(_SSLContext):
    """An SSLContext holds various SSL-related configuration options and
    data, such as certificates and possibly a private key."""

    __slots__ = ('protocol', '__weakref__')
    _windows_cert_stores = ("CA", "ROOT")

    def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs):
        self = _SSLContext.__new__(cls, protocol)
        # _DEFAULT_CIPHERS is not applied to SSLv2 contexts.  When the _ssl
        # build has no PROTOCOL_SSLv2, _SSLv2_IF_EXISTS is None and the
        # default cipher list is always installed.
        if protocol != _SSLv2_IF_EXISTS:
            self.set_ciphers(_DEFAULT_CIPHERS)
        return self

    def __init__(self, protocol=PROTOCOL_TLS):
        self.protocol = protocol

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None, session=None):
        """Return an SSLSocket that wraps *sock* using this context.

        *session* is handed to the new socket as its initial SSL session.
        """
        return SSLSocket(sock=sock, server_side=server_side,
                         do_handshake_on_connect=do_handshake_on_connect,
                         suppress_ragged_eofs=suppress_ragged_eofs,
                         server_hostname=server_hostname,
                         _context=self, _session=session)

    def wrap_bio(self, incoming, outgoing, server_side=False,
                 server_hostname=None, session=None):
        """Return an SSLObject that performs SSL over the *incoming* and
        *outgoing* BIOs (presumably MemoryBIO instances) using this context.
        """
        sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
                                server_hostname=server_hostname)
        return SSLObject(sslobj, session=session)

    @staticmethod
    def _encode_protocol_list(protocols, label):
        """Serialize *protocols* into the length-prefixed wire format.

        Each entry must be a str encodable as ASCII; it is emitted as one
        length byte followed by its encoded bytes.  *label* ('NPN' or 'ALPN')
        is used only in the error message.

        Raises SSLError if an entry encodes to 0 or more than 255 bytes.
        """
        protos = bytearray()
        for protocol in protocols:
            b = bytes(protocol, 'ascii')
            if len(b) == 0 or len(b) > 255:
                raise SSLError(label + ' protocols must be 1 to 255 in length')
            protos.append(len(b))
            protos.extend(b)
        return protos

    def set_npn_protocols(self, npn_protocols):
        """Set the protocols advertised via NPN, in preference order.

        Raises SSLError if a protocol name is empty or longer than 255 bytes.
        """
        protos = self._encode_protocol_list(npn_protocols, 'NPN')
        self._set_npn_protocols(protos)

    def set_alpn_protocols(self, alpn_protocols):
        """Set the protocols advertised via ALPN, in preference order.

        Raises SSLError if a protocol name is empty or longer than 255 bytes.
        """
        protos = self._encode_protocol_list(alpn_protocols, 'ALPN')
        self._set_alpn_protocols(protos)

    def _load_windows_store_certs(self, storename, purpose):
        """Load trusted X.509 certificates from the Windows store *storename*.

        Only certificates trusted for every purpose (trust is True) or for
        *purpose*'s OID are collected.  Called only on win32, where
        enum_certificates is imported.  Returns the collected DER bytes.
        """
        certs = bytearray()
        try:
            for cert, encoding, trust in enum_certificates(storename):
                # CA certs are never PKCS#7 encoded
                if encoding == "x509_asn":
                    if trust is True or purpose.oid in trust:
                        certs.extend(cert)
        except PermissionError:
            warnings.warn("unable to enumerate Windows certificate store")
        if certs:
            self.load_verify_locations(cadata=certs)
        return certs

    def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
        """Load the platform's default CA certificates for *purpose*.

        On Windows the system certificate stores are read first.  In all
        cases set_default_verify_paths() is then called.

        Raises TypeError if *purpose* is not an _ASN1Object.
        """
        if not isinstance(purpose, _ASN1Object):
            raise TypeError(purpose)
        if sys.platform == "win32":
            for storename in self._windows_cert_stores:
                self._load_windows_store_certs(storename, purpose)
        self.set_default_verify_paths()

    # The properties below wrap the raw integers exposed by _SSLContext in
    # the enum types built at import time.  Setters assign through the
    # base-class descriptor, because attribute assignment through a super()
    # proxy is not supported.
    @property
    def options(self):
        return Options(super().options)

    @options.setter
    def options(self, value):
        super(SSLContext, SSLContext).options.__set__(self, value)

    @property
    def verify_flags(self):
        return VerifyFlags(super().verify_flags)

    @verify_flags.setter
    def verify_flags(self, value):
        super(SSLContext, SSLContext).verify_flags.__set__(self, value)

    @property
    def verify_mode(self):
        value = super().verify_mode
        try:
            return VerifyMode(value)
        except ValueError:
            # Not a known VerifyMode member: return the raw value unchanged.
            return value

    @verify_mode.setter
    def verify_mode(self, value):
        super(SSLContext, SSLContext).verify_mode.__set__(self, value)
480
481
def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
                           capath=None, cadata=None):
    """Create a SSLContext object with default settings.

    NOTE: The protocol and settings may change anytime without prior
          deprecation. The values represent a fair balance between maximum
          compatibility and security.

    *purpose* must be an _ASN1Object (normally a Purpose member); otherwise
    TypeError is raised.  CA certificates come from *cafile*, *capath* or
    *cadata* when any is given, else from the system defaults when the
    context verifies peers.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # The SSLContext constructor itself applies OP_NO_SSLv2, OP_NO_SSLv3,
    # OP_NO_COMPRESSION, OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and
    # OP_SINGLE_ECDH_USE.
    ctx = SSLContext(PROTOCOL_TLS)

    if purpose == Purpose.SERVER_AUTH:
        # Client mode: demand a verified certificate and a matching host name.
        ctx.verify_mode = CERT_REQUIRED
        ctx.check_hostname = True
    elif purpose == Purpose.CLIENT_AUTH:
        # Server mode: use the restricted server cipher list.
        ctx.set_ciphers(_RESTRICTED_SERVER_CIPHERS)

    explicit_ca = cafile or capath or cadata
    if explicit_ca:
        ctx.load_verify_locations(cafile, capath, cadata)
    elif ctx.verify_mode != CERT_NONE:
        # Peers are verified but no CA source was given: fall back to the
        # system's default root certificates for *purpose*.  This may fail
        # silently.
        ctx.load_default_certs(purpose)
    return ctx
513
def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None,
                           check_hostname=False, purpose=Purpose.SERVER_AUTH,
                           certfile=None, keyfile=None,
                           cafile=None, capath=None, cadata=None):
    """Create a SSLContext object for Python stdlib modules

    All Python stdlib modules shall use this function to create SSLContext
    objects in order to keep common settings in one place. The configuration
    is less restrictive than create_default_context()'s to increase backward
    compatibility.

    Raises TypeError if *purpose* is not an _ASN1Object, and ValueError if
    *keyfile* is given without *certfile*.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # The SSLContext constructor itself applies OP_NO_SSLv2, OP_NO_SSLv3,
    # OP_NO_COMPRESSION, OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and
    # OP_SINGLE_ECDH_USE.
    ctx = SSLContext(protocol)

    # NOTE: verify_mode is assigned before check_hostname; keep this order
    # (presumably _SSLContext validates one against the other).
    if cert_reqs is not None:
        ctx.verify_mode = cert_reqs
    ctx.check_hostname = check_hostname

    # Client certificate: a key file alone is an error.
    if certfile:
        ctx.load_cert_chain(certfile, keyfile)
    elif keyfile:
        raise ValueError("certfile must be specified")

    # CA roots: explicit sources win; otherwise use the system defaults, but
    # only when peers are actually verified.  The fallback may fail silently.
    if cafile or capath or cadata:
        ctx.load_verify_locations(cafile, capath, cadata)
    elif ctx.verify_mode != CERT_NONE:
        ctx.load_default_certs(purpose)

    return ctx
552
# Old private name kept as an alias for backwards compatibility.
_create_stdlib_context = _create_unverified_context

# Context factory http.client falls back to when no context is passed.
_create_default_https_context = create_default_context
559
560
class SSLObject:
    """This class implements an interface on top of a low-level SSL object as
    implemented by OpenSSL. This object captures the state of an SSL connection
    but does not provide any network IO itself. IO needs to be performed
    through separate "BIO" objects which are OpenSSL's IO abstraction layer.

    This class does not have a public constructor. Instances are returned by
    ``SSLContext.wrap_bio``. This class is typically used by framework authors
    that want to implement asynchronous IO for SSL through memory buffers.

    When compared to ``SSLSocket``, this object lacks the following features:

     * Any form of network IO, including methods such as ``recv`` and ``send``.
     * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
    """

    def __init__(self, sslobj, owner=None, session=None):
        self._sslobj = sslobj
        # Note: _sslobj takes a weak reference to owner
        self._sslobj.owner = owner or self
        if session is not None:
            self._sslobj.session = session

    @property
    def context(self):
        """The SSLContext that is currently in use."""
        return self._sslobj.context

    @context.setter
    def context(self, ctx):
        self._sslobj.context = ctx

    @property
    def session(self):
        """The SSLSession for client socket."""
        return self._sslobj.session

    @session.setter
    def session(self, session):
        self._sslobj.session = session

    @property
    def session_reused(self):
        """Was the client session reused during handshake"""
        return self._sslobj.session_reused

    @property
    def server_side(self):
        """Whether this is a server-side socket."""
        return self._sslobj.server_side

    @property
    def server_hostname(self):
        """The currently set server hostname (for SNI), or ``None`` if no
        server hostname is set."""
        return self._sslobj.server_hostname

    def read(self, len=1024, buffer=None):
        """Read up to 'len' bytes from the SSL object and return them.

        If 'buffer' is provided, read into this buffer and return the number of
        bytes read.
        """
        if buffer is not None:
            return self._sslobj.read(len, buffer)
        return self._sslobj.read(len)

    def write(self, data):
        """Write 'data' to the SSL object and return the number of bytes
        written.

        The 'data' argument must support the buffer interface.
        """
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the certificate provided
        by the other end of the SSL channel.

        Return None if no certificate was provided, {} if a certificate was
        provided, but not validated.
        """
        return self._sslobj.peer_certificate(binary_form)

    def selected_npn_protocol(self):
        """Return the currently selected NPN protocol as a string, or ``None``
        if a next protocol was not negotiated or if NPN is not supported by one
        of the peers."""
        if _ssl.HAS_NPN:
            return self._sslobj.selected_npn_protocol()
        return None

    def selected_alpn_protocol(self):
        """Return the currently selected ALPN protocol as a string, or ``None``
        if a next protocol was not negotiated or if ALPN is not supported by one
        of the peers."""
        if _ssl.HAS_ALPN:
            return self._sslobj.selected_alpn_protocol()
        return None

    def cipher(self):
        """Return the currently selected cipher as a 3-tuple ``(name,
        ssl_version, secret_bits)``."""
        return self._sslobj.cipher()

    def shared_ciphers(self):
        """Return a list of ciphers shared by the client during the handshake or
        None if this is not a valid server connection.
        """
        return self._sslobj.shared_ciphers()

    def compression(self):
        """Return the current compression algorithm in use, or ``None`` if
        compression was not negotiated or not supported by one of the peers."""
        return self._sslobj.compression()

    def pending(self):
        """Return the number of bytes that can be read immediately."""
        return self._sslobj.pending()

    def do_handshake(self):
        """Start the SSL/TLS handshake.

        When the context has check_hostname enabled, the peer certificate is
        then checked with match_hostname(); ValueError is raised if no
        server_hostname was set.
        """
        self._sslobj.do_handshake()
        if self.context.check_hostname:
            if not self.server_hostname:
                raise ValueError("check_hostname needs server_hostname "
                                 "argument")
            match_hostname(self.getpeercert(), self.server_hostname)

    def unwrap(self):
        """Start the SSL shutdown handshake."""
        return self._sslobj.shutdown()

    def get_channel_binding(self, cb_type="tls-unique"):
        """Get channel binding data for current connection.  Raise ValueError
        if the requested `cb_type` is not supported.  Return bytes of the data
        or None if the data is not available (e.g. before the handshake)."""
        if cb_type not in CHANNEL_BINDING_TYPES:
            raise ValueError("Unsupported channel binding type")
        if cb_type != "tls-unique":
            raise NotImplementedError(
                            "{0} channel binding type not implemented"
                            .format(cb_type))
        return self._sslobj.tls_unique_cb()

    def version(self):
        """Return a string identifying the protocol version used by the
        current SSL channel. """
        return self._sslobj.version()
710
711
712class SSLSocket(socket):
713    """This class implements a subtype of socket.socket that wraps
714    the underlying OS socket in an SSL context when necessary, and
715    provides read and write methods over that channel."""
716
717    def __init__(self, sock=None, keyfile=None, certfile=None,
718                 server_side=False, cert_reqs=CERT_NONE,
719                 ssl_version=PROTOCOL_TLS, ca_certs=None,
720                 do_handshake_on_connect=True,
721                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
722                 suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
723                 server_hostname=None,
724                 _context=None, _session=None):
725
726        if _context:
727            self._context = _context
728        else:
729            if server_side and not certfile:
730                raise ValueError("certfile must be specified for server-side "
731                                 "operations")
732            if keyfile and not certfile:
733                raise ValueError("certfile must be specified")
734            if certfile and not keyfile:
735                keyfile = certfile
736            self._context = SSLContext(ssl_version)
737            self._context.verify_mode = cert_reqs
738            if ca_certs:
739                self._context.load_verify_locations(ca_certs)
740            if certfile:
741                self._context.load_cert_chain(certfile, keyfile)
742            if npn_protocols:
743                self._context.set_npn_protocols(npn_protocols)
744            if ciphers:
745                self._context.set_ciphers(ciphers)
746            self.keyfile = keyfile
747            self.certfile = certfile
748            self.cert_reqs = cert_reqs
749            self.ssl_version = ssl_version
750            self.ca_certs = ca_certs
751            self.ciphers = ciphers
752        # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
753        # mixed in.
754        if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
755            raise NotImplementedError("only stream sockets are supported")
756        if server_side:
757            if server_hostname:
758                raise ValueError("server_hostname can only be specified "
759                                 "in client mode")
760            if _session is not None:
761                raise ValueError("session can only be specified in "
762                                 "client mode")
763        if self._context.check_hostname and not server_hostname:
764            raise ValueError("check_hostname requires server_hostname")
765        self._session = _session
766        self.server_side = server_side
767        self.server_hostname = server_hostname
768        self.do_handshake_on_connect = do_handshake_on_connect
769        self.suppress_ragged_eofs = suppress_ragged_eofs
770        if sock is not None:
771            socket.__init__(self,
772                            family=sock.family,
773                            type=sock.type,
774                            proto=sock.proto,
775                            fileno=sock.fileno())
776            self.settimeout(sock.gettimeout())
777            sock.detach()
778        elif fileno is not None:
779            socket.__init__(self, fileno=fileno)
780        else:
781            socket.__init__(self, family=family, type=type, proto=proto)
782
783        # See if we are connected
784        try:
785            self.getpeername()
786        except OSError as e:
787            if e.errno != errno.ENOTCONN:
788                raise
789            connected = False
790        else:
791            connected = True
792
793        self._closed = False
794        self._sslobj = None
795        self._connected = connected
796        if connected:
797            # create the SSL object
798            try:
799                sslobj = self._context._wrap_socket(self, server_side,
800                                                    server_hostname)
801                self._sslobj = SSLObject(sslobj, owner=self,
802                                         session=self._session)
803                if do_handshake_on_connect:
804                    timeout = self.gettimeout()
805                    if timeout == 0.0:
806                        # non-blocking
807                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
808                    self.do_handshake()
809
810            except (OSError, ValueError):
811                self.close()
812                raise
813
814    @property
815    def context(self):
816        return self._context
817
818    @context.setter
819    def context(self, ctx):
820        self._context = ctx
821        self._sslobj.context = ctx
822
823    @property
824    def session(self):
825        """The SSLSession for client socket."""
826        if self._sslobj is not None:
827            return self._sslobj.session
828
829    @session.setter
830    def session(self, session):
831        self._session = session
832        if self._sslobj is not None:
833            self._sslobj.session = session
834
835    @property
836    def session_reused(self):
837        """Was the client session reused during handshake"""
838        if self._sslobj is not None:
839            return self._sslobj.session_reused
840
841    def dup(self):
842        raise NotImplemented("Can't dup() %s instances" %
843                             self.__class__.__name__)
844
845    def _checkClosed(self, msg=None):
846        # raise an exception here if you wish to check for spurious closes
847        pass
848
849    def _check_connected(self):
850        if not self._connected:
851            # getpeername() will raise ENOTCONN if the socket is really
852            # not connected; note that we can be connected even without
853            # _connected being set, e.g. if connect() first returned
854            # EAGAIN.
855            self.getpeername()
856
857    def read(self, len=1024, buffer=None):
858        """Read up to LEN bytes and return them.
859        Return zero-length string on EOF."""
860
861        self._checkClosed()
862        if not self._sslobj:
863            raise ValueError("Read on closed or unwrapped SSL socket.")
864        try:
865            return self._sslobj.read(len, buffer)
866        except SSLError as x:
867            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
868                if buffer is not None:
869                    return 0
870                else:
871                    return b''
872            else:
873                raise
874
875    def write(self, data):
876        """Write DATA to the underlying SSL channel.  Returns
877        number of bytes of DATA actually transmitted."""
878
879        self._checkClosed()
880        if not self._sslobj:
881            raise ValueError("Write on closed or unwrapped SSL socket.")
882        return self._sslobj.write(data)
883
884    def getpeercert(self, binary_form=False):
885        """Returns a formatted version of the data in the
886        certificate provided by the other end of the SSL channel.
887        Return None if no certificate was provided, {} if a
888        certificate was provided, but not validated."""
889
890        self._checkClosed()
891        self._check_connected()
892        return self._sslobj.getpeercert(binary_form)
893
894    def selected_npn_protocol(self):
895        self._checkClosed()
896        if not self._sslobj or not _ssl.HAS_NPN:
897            return None
898        else:
899            return self._sslobj.selected_npn_protocol()
900
901    def selected_alpn_protocol(self):
902        self._checkClosed()
903        if not self._sslobj or not _ssl.HAS_ALPN:
904            return None
905        else:
906            return self._sslobj.selected_alpn_protocol()
907
908    def cipher(self):
909        self._checkClosed()
910        if not self._sslobj:
911            return None
912        else:
913            return self._sslobj.cipher()
914
915    def shared_ciphers(self):
916        self._checkClosed()
917        if not self._sslobj:
918            return None
919        return self._sslobj.shared_ciphers()
920
921    def compression(self):
922        self._checkClosed()
923        if not self._sslobj:
924            return None
925        else:
926            return self._sslobj.compression()
927
928    def send(self, data, flags=0):
929        self._checkClosed()
930        if self._sslobj:
931            if flags != 0:
932                raise ValueError(
933                    "non-zero flags not allowed in calls to send() on %s" %
934                    self.__class__)
935            return self._sslobj.write(data)
936        else:
937            return socket.send(self, data, flags)
938
939    def sendto(self, data, flags_or_addr, addr=None):
940        self._checkClosed()
941        if self._sslobj:
942            raise ValueError("sendto not allowed on instances of %s" %
943                             self.__class__)
944        elif addr is None:
945            return socket.sendto(self, data, flags_or_addr)
946        else:
947            return socket.sendto(self, data, flags_or_addr, addr)
948
949    def sendmsg(self, *args, **kwargs):
950        # Ensure programs don't send data unencrypted if they try to
951        # use this method.
952        raise NotImplementedError("sendmsg not allowed on instances of %s" %
953                                  self.__class__)
954
955    def sendall(self, data, flags=0):
956        self._checkClosed()
957        if self._sslobj:
958            if flags != 0:
959                raise ValueError(
960                    "non-zero flags not allowed in calls to sendall() on %s" %
961                    self.__class__)
962            amount = len(data)
963            count = 0
964            while (count < amount):
965                v = self.send(data[count:])
966                count += v
967        else:
968            return socket.sendall(self, data, flags)
969
970    def sendfile(self, file, offset=0, count=None):
971        """Send a file, possibly by using os.sendfile() if this is a
972        clear-text socket.  Return the total number of bytes sent.
973        """
974        if self._sslobj is None:
975            # os.sendfile() works with plain sockets only
976            return super().sendfile(file, offset, count)
977        else:
978            return self._sendfile_use_send(file, offset, count)
979
980    def recv(self, buflen=1024, flags=0):
981        self._checkClosed()
982        if self._sslobj:
983            if flags != 0:
984                raise ValueError(
985                    "non-zero flags not allowed in calls to recv() on %s" %
986                    self.__class__)
987            return self.read(buflen)
988        else:
989            return socket.recv(self, buflen, flags)
990
991    def recv_into(self, buffer, nbytes=None, flags=0):
992        self._checkClosed()
993        if buffer and (nbytes is None):
994            nbytes = len(buffer)
995        elif nbytes is None:
996            nbytes = 1024
997        if self._sslobj:
998            if flags != 0:
999                raise ValueError(
1000                  "non-zero flags not allowed in calls to recv_into() on %s" %
1001                  self.__class__)
1002            return self.read(nbytes, buffer)
1003        else:
1004            return socket.recv_into(self, buffer, nbytes, flags)
1005
1006    def recvfrom(self, buflen=1024, flags=0):
1007        self._checkClosed()
1008        if self._sslobj:
1009            raise ValueError("recvfrom not allowed on instances of %s" %
1010                             self.__class__)
1011        else:
1012            return socket.recvfrom(self, buflen, flags)
1013
1014    def recvfrom_into(self, buffer, nbytes=None, flags=0):
1015        self._checkClosed()
1016        if self._sslobj:
1017            raise ValueError("recvfrom_into not allowed on instances of %s" %
1018                             self.__class__)
1019        else:
1020            return socket.recvfrom_into(self, buffer, nbytes, flags)
1021
1022    def recvmsg(self, *args, **kwargs):
1023        raise NotImplementedError("recvmsg not allowed on instances of %s" %
1024                                  self.__class__)
1025
1026    def recvmsg_into(self, *args, **kwargs):
1027        raise NotImplementedError("recvmsg_into not allowed on instances of "
1028                                  "%s" % self.__class__)
1029
1030    def pending(self):
1031        self._checkClosed()
1032        if self._sslobj:
1033            return self._sslobj.pending()
1034        else:
1035            return 0
1036
1037    def shutdown(self, how):
1038        self._checkClosed()
1039        self._sslobj = None
1040        socket.shutdown(self, how)
1041
1042    def unwrap(self):
1043        if self._sslobj:
1044            s = self._sslobj.unwrap()
1045            self._sslobj = None
1046            return s
1047        else:
1048            raise ValueError("No SSL wrapper around " + str(self))
1049
1050    def _real_close(self):
1051        self._sslobj = None
1052        socket._real_close(self)
1053
1054    def do_handshake(self, block=False):
1055        """Perform a TLS/SSL handshake."""
1056        self._check_connected()
1057        timeout = self.gettimeout()
1058        try:
1059            if timeout == 0.0 and block:
1060                self.settimeout(None)
1061            self._sslobj.do_handshake()
1062        finally:
1063            self.settimeout(timeout)
1064
    def _real_connect(self, addr, connect_ex):
        """Connect to *addr* and attach a client-side SSL object.

        Shared implementation of connect() and connect_ex().  When
        *connect_ex* is true, socket.connect_ex() is used and its return
        value is returned; otherwise socket.connect() is used and None is
        returned.

        Raises ValueError for server-side or already-connected sockets.  If
        connecting or the handshake raises OSError/ValueError, the SSL object
        is discarded and the exception re-raised.
        """
        if self.server_side:
            raise ValueError("can't connect in server-side mode")
        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._connected:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        # The SSL object is created before the TCP connect, from the current
        # context, owned by this socket and seeded with any stored session.
        sslobj = self.context._wrap_socket(self, False, self.server_hostname)
        self._sslobj = SSLObject(sslobj, owner=self,
                                 session=self._session)
        try:
            if connect_ex:
                rc = socket.connect_ex(self, addr)
            else:
                rc = None
                socket.connect(self, addr)
            # rc is None after a successful connect(), or connect_ex()'s
            # result (falsy on success); only then is the socket marked
            # connected and, if requested, the handshake performed.
            if not rc:
                self._connected = True
                if self.do_handshake_on_connect:
                    self.do_handshake()
            # NOTE(review): a truthy rc (e.g. EINPROGRESS on a non-blocking
            # socket) keeps _sslobj but leaves _connected False; presumably
            # the caller finishes connecting/handshaking later — confirm.
            return rc
        except (OSError, ValueError):
            # Drop the half-initialized SSL object so the socket is left
            # unwrapped, then propagate the error.
            self._sslobj = None
            raise
1089
1090    def connect(self, addr):
1091        """Connects to remote ADDR, and then wraps the connection in
1092        an SSL channel."""
1093        self._real_connect(addr, False)
1094
1095    def connect_ex(self, addr):
1096        """Connects to remote ADDR, and then wraps the connection in
1097        an SSL channel."""
1098        return self._real_connect(addr, True)
1099
1100    def accept(self):
1101        """Accepts a new connection from a remote client, and returns
1102        a tuple containing that new connection wrapped with a server-side
1103        SSL channel, and the address of the remote client."""
1104
1105        newsock, addr = socket.accept(self)
1106        newsock = self.context.wrap_socket(newsock,
1107                    do_handshake_on_connect=self.do_handshake_on_connect,
1108                    suppress_ragged_eofs=self.suppress_ragged_eofs,
1109                    server_side=True)
1110        return newsock, addr
1111
1112    def get_channel_binding(self, cb_type="tls-unique"):
1113        """Get channel binding data for current connection.  Raise ValueError
1114        if the requested `cb_type` is not supported.  Return bytes of the data
1115        or None if the data is not available (e.g. before the handshake).
1116        """
1117        if self._sslobj is None:
1118            return None
1119        return self._sslobj.get_channel_binding(cb_type)
1120
1121    def version(self):
1122        """
1123        Return a string identifying the protocol version used by the
1124        current SSL channel, or None if there is no established channel.
1125        """
1126        if self._sslobj is None:
1127            return None
1128        return self._sslobj.version()
1129
1130
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_TLS, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """Wrap SOCK in an SSLSocket, forwarding every option by keyword."""
    options = dict(keyfile=keyfile, certfile=certfile,
                   server_side=server_side, cert_reqs=cert_reqs,
                   ssl_version=ssl_version, ca_certs=ca_certs,
                   do_handshake_on_connect=do_handshake_on_connect,
                   suppress_ragged_eofs=suppress_ragged_eofs,
                   ciphers=ciphers)
    return SSLSocket(sock=sock, **options)
1143
1144# some utility functions
1145
def cert_time_to_seconds(cert_time):
    """Convert a certificate "notBefore"/"notAfter" string to epoch seconds.

    CERT_TIME must look like ``"%b %d %H:%M:%S %Y %Z"`` (C locale), e.g.
    ``"Jan  5 09:34:43 2018 GMT"``.  Per RFC 5280 these dates are UTC, which
    is written as GMT (see ASN1_TIME_print()).  The month abbreviation is
    matched case-insensitively against Jan..Dec and parsed by hand, so the
    result does not depend on the current locale.

    Returns an int.  Raises ValueError if the string does not match.
    """
    from time import strptime
    from calendar import timegm

    month_names = tuple("Jan Feb Mar Apr May Jun "
                        "Jul Aug Sep Oct Nov Dec".split())
    time_format = ' %d %H:%M:%S %Y GMT'  # NOTE: no month, fixed GMT
    try:
        month = month_names.index(cert_time[:3].title()) + 1
    except ValueError:
        raise ValueError('time data %r does not match '
                         'format "%%b%s"' % (cert_time, time_format))
    parsed = strptime(cert_time[3:], time_format)
    # timegm() interprets the tuple as UTC and returns an int (the earlier
    # mktime()-based version returned a float with a zero fraction).
    return timegm((parsed[0], month) + parsed[2:6])
1175
# Delimiter lines that enclose a base64-encoded certificate in PEM form.
PEM_HEADER, PEM_FOOTER = ("-----BEGIN CERTIFICATE-----",
                          "-----END CERTIFICATE-----")
1178
def DER_cert_to_PEM_cert(der_cert_bytes):
    """Convert a binary DER certificate to a PEM string.

    The base64 body is wrapped at 64 columns between the PEM header and
    footer lines, and the result ends with a newline.
    """
    encoded = base64.standard_b64encode(der_cert_bytes).decode('ASCII', 'strict')
    body = textwrap.fill(encoded, 64)
    return '\n'.join((PEM_HEADER, body, PEM_FOOTER)) + '\n'
1187
def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Whitespace before the header or after the footer is ignored.

    Raises:
        ValueError: if the text does not start with PEM_HEADER or end with
            PEM_FOOTER (after stripping surrounding whitespace).
    """
    # Validate both ends against the same stripped text.  The header check
    # used the raw string while the footer check and the slice below used the
    # stripped one, so leading whitespace was rejected but trailing
    # whitespace accepted.
    stripped = pem_cert_string.strip()
    if not stripped.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not stripped.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = stripped[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodebytes(d.encode('ASCII', 'strict'))
1200
def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None):
    """Fetch the certificate of the server at ADDR as a PEM string.

    ADDR must be a ``(host, port)`` pair.  If CA_CERTS is given, the server
    certificate is required and validated against it; otherwise no
    validation is requested.  SSL_VERSION selects the protocol.
    """
    host, port = addr  # insists on a (host, port) pair
    cert_reqs = CERT_NONE if ca_certs is None else CERT_REQUIRED
    context = _create_stdlib_context(ssl_version, cert_reqs=cert_reqs,
                                     cafile=ca_certs)
    with create_connection(addr) as sock, context.wrap_socket(sock) as sslsock:
        dercert = sslsock.getpeercert(True)
    return DER_cert_to_PEM_cert(dercert)
1219
def get_protocol_name(protocol_code):
    """Return the registered name for PROTOCOL_CODE, or '<unknown>'."""
    name = _PROTOCOL_NAMES.get(protocol_code, '<unknown>')
    return name
1222