tlsrecordlayer.py revision f8ee788a64d60abd8f2d742a5fdedde054ecd910
1# Authors:
2#   Trevor Perrin
3#   Google (adapted by Sam Rushing) - NPN support
4#   Martin von Loewis - python 3 port
5#
6# See the LICENSE file for legal information regarding use of this file.
7
8"""Helper class for TLSConnection."""
9from __future__ import generators
10
11from .utils.compat import *
12from .utils.cryptomath import *
13from .utils.cipherfactory import createAES, createRC4, createTripleDES
14from .utils.codec import *
15from .errors import *
16from .messages import *
17from .mathtls import *
18from .constants import *
19from .utils.cryptomath import getRandomBytes
20
21import socket
22import errno
23import traceback
24
25class _ConnectionState(object):
26    def __init__(self):
27        self.macContext = None
28        self.encContext = None
29        self.seqnum = 0
30
31    def getSeqNumBytes(self):
32        w = Writer()
33        w.add(self.seqnum, 8)
34        self.seqnum += 1
35        return w.bytes
36
37
38class TLSRecordLayer(object):
39    """
40    This class handles data transmission for a TLS connection.
41
42    Its only subclass is L{tlslite.TLSConnection.TLSConnection}.  We've
43    separated the code in this class from TLSConnection to make things
44    more readable.
45
46
47    @type sock: socket.socket
48    @ivar sock: The underlying socket object.
49
50    @type session: L{tlslite.Session.Session}
51    @ivar session: The session corresponding to this connection.
52
53    Due to TLS session resumption, multiple connections can correspond
54    to the same underlying session.
55
56    @type version: tuple
57    @ivar version: The TLS version being used for this connection.
58
    (3,0) means SSL 3.0, (3,1) means TLS 1.0, and (3,2) means TLS 1.1.
60
61    @type closed: bool
62    @ivar closed: If this connection is closed.
63
64    @type resumed: bool
65    @ivar resumed: If this connection is based on a resumed session.
66
67    @type allegedSrpUsername: str or None
68    @ivar allegedSrpUsername:  This is set to the SRP username
69    asserted by the client, whether the handshake succeeded or not.
70    If the handshake fails, this can be inspected to determine
71    if a guessing attack is in progress against a particular user
72    account.
73
74    @type closeSocket: bool
75    @ivar closeSocket: If the socket should be closed when the
76    connection is closed, defaults to True (writable).
77
78    If you set this to True, TLS Lite will assume the responsibility of
79    closing the socket when the TLS Connection is shutdown (either
    through an error or through the user calling close()).  The default
    is True.
82
83    @type ignoreAbruptClose: bool
    @ivar ignoreAbruptClose: If an abrupt close of the socket should be
    ignored rather than raise an error (writable).
86
87    If you set this to True, TLS Lite will not raise a
88    L{tlslite.errors.TLSAbruptCloseError} exception if the underlying
89    socket is unexpectedly closed.  Such an unexpected closure could be
90    caused by an attacker.  However, it also occurs with some incorrect
91    TLS implementations.
92
93    You should set this to True only if you're not worried about an
94    attacker truncating the connection, and only if necessary to avoid
95    spurious errors.  The default is False.
96
97    @sort: __init__, read, readAsync, write, writeAsync, close, closeAsync,
98    getCipherImplementation, getCipherName
99    """
100
101    def __init__(self, sock):
102        self.sock = sock
103
104        #My session object (Session instance; read-only)
105        self.session = None
106
107        #Am I a client or server?
108        self._client = None
109
110        #Buffers for processing messages
111        self._handshakeBuffer = []
112        self.clearReadBuffer()
113        self.clearWriteBuffer()
114
115        #Handshake digests
116        self._handshake_md5 = hashlib.md5()
117        self._handshake_sha = hashlib.sha1()
118
119        #TLS Protocol Version
120        self.version = (0,0) #read-only
121        self._versionCheck = False #Once we choose a version, this is True
122
123        #Current and Pending connection states
124        self._writeState = _ConnectionState()
125        self._readState = _ConnectionState()
126        self._pendingWriteState = _ConnectionState()
127        self._pendingReadState = _ConnectionState()
128
129        #Is the connection open?
130        self.closed = True #read-only
131        self._refCount = 0 #Used to trigger closure
132
133        #Is this a resumed session?
134        self.resumed = False #read-only
135
136        #What username did the client claim in his handshake?
137        self.allegedSrpUsername = None
138
139        #On a call to close(), do we close the socket? (writeable)
140        self.closeSocket = True
141
142        #If the socket is abruptly closed, do we ignore it
143        #and pretend the connection was shut down properly? (writeable)
144        self.ignoreAbruptClose = False
145
146        #Fault we will induce, for testing purposes
147        self.fault = None
148
149    def clearReadBuffer(self):
150        self._readBuffer = b''
151
152    def clearWriteBuffer(self):
153        self._send_writer = None
154
155
156    #*********************************************************
157    # Public Functions START
158    #*********************************************************
159
160    def read(self, max=None, min=1):
161        """Read some data from the TLS connection.
162
163        This function will block until at least 'min' bytes are
164        available (or the connection is closed).
165
166        If an exception is raised, the connection will have been
167        automatically closed.
168
169        @type max: int
170        @param max: The maximum number of bytes to return.
171
172        @type min: int
173        @param min: The minimum number of bytes to return
174
175        @rtype: str
176        @return: A string of no more than 'max' bytes, and no fewer
177        than 'min' (unless the connection has been closed, in which
178        case fewer than 'min' bytes may be returned).
179
180        @raise socket.error: If a socket error occurs.
181        @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
182        without a preceding alert.
183        @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
184        """
185        for result in self.readAsync(max, min):
186            pass
187        return result
188
189    def readAsync(self, max=None, min=1):
190        """Start a read operation on the TLS connection.
191
192        This function returns a generator which behaves similarly to
193        read().  Successive invocations of the generator will return 0
194        if it is waiting to read from the socket, 1 if it is waiting
195        to write to the socket, or a string if the read operation has
196        completed.
197
198        @rtype: iterable
199        @return: A generator; see above for details.
200        """
201        try:
202            while len(self._readBuffer)<min and not self.closed:
203                try:
204                    for result in self._getMsg(ContentType.application_data):
205                        if result in (0,1):
206                            yield result
207                    applicationData = result
208                    self._readBuffer += applicationData.write()
209                except TLSRemoteAlert as alert:
210                    if alert.description != AlertDescription.close_notify:
211                        raise
212                except TLSAbruptCloseError:
213                    if not self.ignoreAbruptClose:
214                        raise
215                    else:
216                        self._shutdown(True)
217
218            if max == None:
219                max = len(self._readBuffer)
220
221            returnBytes = self._readBuffer[:max]
222            self._readBuffer = self._readBuffer[max:]
223            yield bytes(returnBytes)
224        except GeneratorExit:
225            raise
226        except:
227            self._shutdown(False)
228            raise
229
230    def unread(self, b):
231        """Add bytes to the front of the socket read buffer for future
232        reading. Be careful using this in the context of select(...): if you
233        unread the last data from a socket, that won't wake up selected waiters,
234        and those waiters may hang forever.
235        """
236        self._readBuffer = b + self._readBuffer
237
238    def write(self, s):
239        """Write some data to the TLS connection.
240
241        This function will block until all the data has been sent.
242
243        If an exception is raised, the connection will have been
244        automatically closed.
245
246        @type s: str
247        @param s: The data to transmit to the other party.
248
249        @raise socket.error: If a socket error occurs.
250        """
251        for result in self.writeAsync(s):
252            pass
253
254    def writeAsync(self, s):
255        """Start a write operation on the TLS connection.
256
257        This function returns a generator which behaves similarly to
258        write().  Successive invocations of the generator will return
259        1 if it is waiting to write to the socket, or will raise
260        StopIteration if the write operation has completed.
261
262        @rtype: iterable
263        @return: A generator; see above for details.
264        """
265        try:
266            if self.closed:
267                raise TLSClosedConnectionError("attempt to write to closed connection")
268
269            index = 0
270            blockSize = 16384
271            randomizeFirstBlock = True
272            while 1:
273                startIndex = index * blockSize
274                endIndex = startIndex + blockSize
275                if startIndex >= len(s):
276                    break
277                if endIndex > len(s):
278                    endIndex = len(s)
279                block = bytearray(s[startIndex : endIndex])
280                applicationData = ApplicationData().create(block)
281                for result in self._sendMsg(applicationData, \
282                                            randomizeFirstBlock):
283                    yield result
284                randomizeFirstBlock = False #only on 1st message
285                index += 1
286        except GeneratorExit:
287            raise
288        except Exception:
289            # Don't invalidate the session on write failure if abrupt closes are
290            # okay.
291            self._shutdown(self.ignoreAbruptClose)
292            raise
293
294    def close(self):
295        """Close the TLS connection.
296
297        This function will block until it has exchanged close_notify
298        alerts with the other party.  After doing so, it will shut down the
299        TLS connection.  Further attempts to read through this connection
300        will return "".  Further attempts to write through this connection
301        will raise ValueError.
302
303        If makefile() has been called on this connection, the connection
304        will be not be closed until the connection object and all file
305        objects have been closed.
306
307        Even if an exception is raised, the connection will have been
308        closed.
309
310        @raise socket.error: If a socket error occurs.
311        @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
312        without a preceding alert.
313        @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
314        """
315        if not self.closed:
316            for result in self._decrefAsync():
317                pass
318
319    # Python 3 callback
320    _decref_socketios = close
321
322    def closeAsync(self):
323        """Start a close operation on the TLS connection.
324
325        This function returns a generator which behaves similarly to
326        close().  Successive invocations of the generator will return 0
327        if it is waiting to read from the socket, 1 if it is waiting
328        to write to the socket, or will raise StopIteration if the
329        close operation has completed.
330
331        @rtype: iterable
332        @return: A generator; see above for details.
333        """
334        if not self.closed:
335            for result in self._decrefAsync():
336                yield result
337
338    def _decrefAsync(self):
339        self._refCount -= 1
340        if self._refCount == 0 and not self.closed:
341            try:
342                for result in self._sendMsg(Alert().create(\
343                        AlertDescription.close_notify, AlertLevel.warning)):
344                    yield result
345                alert = None
346                # By default close the socket, since it's been observed
347                # that some other libraries will not respond to the
348                # close_notify alert, thus leaving us hanging if we're
349                # expecting it
350                if self.closeSocket:
351                    self._shutdown(True)
352                else:
353                    while not alert:
354                        for result in self._getMsg((ContentType.alert, \
355                                                  ContentType.application_data)):
356                            if result in (0,1):
357                                yield result
358                        if result.contentType == ContentType.alert:
359                            alert = result
360                    if alert.description == AlertDescription.close_notify:
361                        self._shutdown(True)
362                    else:
363                        raise TLSRemoteAlert(alert)
364            except (socket.error, TLSAbruptCloseError):
365                #If the other side closes the socket, that's okay
366                self._shutdown(True)
367            except GeneratorExit:
368                raise
369            except:
370                self._shutdown(False)
371                raise
372
373    def getVersionName(self):
374        """Get the name of this TLS version.
375
376        @rtype: str
377        @return: The name of the TLS version used with this connection.
378        Either None, 'SSL 3.0', 'TLS 1.0', or 'TLS 1.1'.
379        """
380        if self.version == (3,0):
381            return "SSL 3.0"
382        elif self.version == (3,1):
383            return "TLS 1.0"
384        elif self.version == (3,2):
385            return "TLS 1.1"
386        else:
387            return None
388
389    def getCipherName(self):
390        """Get the name of the cipher used with this connection.
391
392        @rtype: str
393        @return: The name of the cipher used with this connection.
394        Either 'aes128', 'aes256', 'rc4', or '3des'.
395        """
396        if not self._writeState.encContext:
397            return None
398        return self._writeState.encContext.name
399
400    def getCipherImplementation(self):
401        """Get the name of the cipher implementation used with
402        this connection.
403
404        @rtype: str
405        @return: The name of the cipher implementation used with
406        this connection.  Either 'python', 'openssl', or 'pycrypto'.
407        """
408        if not self._writeState.encContext:
409            return None
410        return self._writeState.encContext.implementation
411
412
413
414    #Emulate a socket, somewhat -
415    def send(self, s):
416        """Send data to the TLS connection (socket emulation).
417
418        @raise socket.error: If a socket error occurs.
419        """
420        self.write(s)
421        return len(s)
422
423    def sendall(self, s):
424        """Send data to the TLS connection (socket emulation).
425
426        @raise socket.error: If a socket error occurs.
427        """
428        self.write(s)
429
430    def recv(self, bufsize):
431        """Get some data from the TLS connection (socket emulation).
432
433        @raise socket.error: If a socket error occurs.
434        @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
435        without a preceding alert.
436        @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
437        """
438        return self.read(bufsize)
439
440    def recv_into(self, b):
441        # XXX doc string
442        data = self.read(len(b))
443        if not data:
444            return None
445        b[:len(data)] = data
446        return len(data)
447
448    def makefile(self, mode='r', bufsize=-1):
449        """Create a file object for the TLS connection (socket emulation).
450
451        @rtype: L{socket._fileobject}
452        """
453        self._refCount += 1
454        # So, it is pretty fragile to be using Python internal objects
455        # like this, but it is probably the best/easiest way to provide
456        # matching behavior for socket emulation purposes.  The 'close'
457        # argument is nice, its apparently a recent addition to this
458        # class, so that when fileobject.close() gets called, it will
459        # close() us, causing the refcount to be decremented (decrefAsync).
460        #
461        # If this is the last close() on the outstanding fileobjects /
462        # TLSConnection, then the "actual" close alerts will be sent,
463        # socket closed, etc.
464        if sys.version_info < (3,):
465            return socket._fileobject(self, mode, bufsize, close=True)
466        else:
467            # XXX need to wrap this further if buffering is requested
468            return socket.SocketIO(self, mode)
469
470    def getsockname(self):
471        """Return the socket's own address (socket emulation)."""
472        return self.sock.getsockname()
473
474    def getpeername(self):
475        """Return the remote address to which the socket is connected
476        (socket emulation)."""
477        return self.sock.getpeername()
478
479    def settimeout(self, value):
480        """Set a timeout on blocking socket operations (socket emulation)."""
481        return self.sock.settimeout(value)
482
483    def gettimeout(self):
484        """Return the timeout associated with socket operations (socket
485        emulation)."""
486        return self.sock.gettimeout()
487
488    def setsockopt(self, level, optname, value):
489        """Set the value of the given socket option (socket emulation)."""
490        return self.sock.setsockopt(level, optname, value)
491
492    def shutdown(self, how):
493        """Shutdown the underlying socket."""
494        return self.sock.shutdown(how)
495
496    def fileno(self):
497        """Not implement in TLS Lite."""
498        raise NotImplementedError()
499
500
501     #*********************************************************
502     # Public Functions END
503     #*********************************************************
504
505    def _shutdown(self, resumable):
506        self._writeState = _ConnectionState()
507        self._readState = _ConnectionState()
508        self.version = (0,0)
509        self._versionCheck = False
510        self.closed = True
511        if self.closeSocket:
512            self.sock.close()
513
514        #Even if resumable is False, we'll never toggle this on
515        if not resumable and self.session:
516            self.session.resumable = False
517
518
519    def _sendError(self, alertDescription, errorStr=None):
520        alert = Alert().create(alertDescription, AlertLevel.fatal)
521        for result in self._sendMsg(alert):
522            yield result
523        self._shutdown(False)
524        raise TLSLocalAlert(alert, errorStr)
525
526    def _sendMsgs(self, msgs):
527        randomizeFirstBlock = True
528        for msg in msgs:
529            for result in self._sendMsg(msg, randomizeFirstBlock):
530                yield result
531            randomizeFirstBlock = True
532
533    def _sendMsg(self, msg, randomizeFirstBlock = True):
534        #Whenever we're connected and asked to send an app data message,
535        #we first send the first byte of the message.  This prevents
536        #an attacker from launching a chosen-plaintext attack based on
537        #knowing the next IV (a la BEAST).
538        if not self.closed and randomizeFirstBlock and self.version <= (3,1) \
539                and self._writeState.encContext \
540                and self._writeState.encContext.isBlockCipher \
541                and isinstance(msg, ApplicationData):
542            msgFirstByte = msg.splitFirstByte()
543            for result in self._sendMsg(msgFirstByte,
544                                       randomizeFirstBlock = False):
545                yield result
546
547        b = msg.write()
548
549        # If a 1-byte message was passed in, and we "split" the
550        # first(only) byte off above, we may have a 0-length msg:
551        if len(b) == 0:
552            return
553
554        contentType = msg.contentType
555
556        #Update handshake hashes
557        if contentType == ContentType.handshake:
558            self._handshake_md5.update(compat26Str(b))
559            self._handshake_sha.update(compat26Str(b))
560
561        #Calculate MAC
562        if self._writeState.macContext:
563            seqnumBytes = self._writeState.getSeqNumBytes()
564            mac = self._writeState.macContext.copy()
565            mac.update(compatHMAC(seqnumBytes))
566            mac.update(compatHMAC(bytearray([contentType])))
567            if self.version == (3,0):
568                mac.update( compatHMAC( bytearray([len(b)//256] )))
569                mac.update( compatHMAC( bytearray([len(b)%256] )))
570            elif self.version in ((3,1), (3,2)):
571                mac.update(compatHMAC( bytearray([self.version[0]] )))
572                mac.update(compatHMAC( bytearray([self.version[1]] )))
573                mac.update( compatHMAC( bytearray([len(b)//256] )))
574                mac.update( compatHMAC( bytearray([len(b)%256] )))
575            else:
576                raise AssertionError()
577            mac.update(compatHMAC(b))
578            macBytes = bytearray(mac.digest())
579            if self.fault == Fault.badMAC:
580                macBytes[0] = (macBytes[0]+1) % 256
581
582        #Encrypt for Block or Stream Cipher
583        if self._writeState.encContext:
584            #Add padding and encrypt (for Block Cipher):
585            if self._writeState.encContext.isBlockCipher:
586
587                #Add TLS 1.1 fixed block
588                if self.version == (3,2):
589                    b = self.fixedIVBlock + b
590
591                #Add padding: b = b + (macBytes + paddingBytes)
592                currentLength = len(b) + len(macBytes)
593                blockLength = self._writeState.encContext.block_size
594                paddingLength = blockLength - 1 - (currentLength % blockLength)
595
596                paddingBytes = bytearray([paddingLength] * (paddingLength+1))
597                if self.fault == Fault.badPadding:
598                    paddingBytes[0] = (paddingBytes[0]+1) % 256
599                endBytes = macBytes + paddingBytes
600                b += endBytes
601                #Encrypt
602                b = self._writeState.encContext.encrypt(b)
603
604            #Encrypt (for Stream Cipher)
605            else:
606                b += macBytes
607                b = self._writeState.encContext.encrypt(b)
608
609        #Add record header and send
610        r = RecordHeader3().create(self.version, contentType, len(b))
611        s = r.write() + b
612        while 1:
613            try:
614                bytesSent = self.sock.send(s) #Might raise socket.error
615            except socket.error as why:
616                if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
617                    yield 1
618                    continue
619                else:
620                    # The socket was unexpectedly closed.  The tricky part
621                    # is that there may be an alert sent by the other party
622                    # sitting in the read buffer.  So, if we get here after
623                    # handshaking, we will just raise the error and let the
624                    # caller read more data if it would like, thus stumbling
625                    # upon the error.
626                    #
627                    # However, if we get here DURING handshaking, we take
628                    # it upon ourselves to see if the next message is an
629                    # Alert.
630                    if contentType == ContentType.handshake:
631
632                        # See if there's an alert record
633                        # Could raise socket.error or TLSAbruptCloseError
634                        for result in self._getNextRecord():
635                            if result in (0,1):
636                                yield result
637
638                        # Closes the socket
639                        self._shutdown(False)
640
641                        # If we got an alert, raise it
642                        recordHeader, p = result
643                        if recordHeader.type == ContentType.alert:
644                            alert = Alert().parse(p)
645                            raise TLSRemoteAlert(alert)
646                    else:
647                        # If we got some other message who know what
648                        # the remote side is doing, just go ahead and
649                        # raise the socket.error
650                        raise
651            if bytesSent == len(s):
652                return
653            s = s[bytesSent:]
654            yield 1
655
656
657    def _getMsg(self, expectedType, secondaryType=None, constructorType=None):
658        try:
659            if not isinstance(expectedType, tuple):
660                expectedType = (expectedType,)
661
662            #Spin in a loop, until we've got a non-empty record of a type we
663            #expect.  The loop will be repeated if:
664            #  - we receive a renegotiation attempt; we send no_renegotiation,
665            #    then try again
666            #  - we receive an empty application-data fragment; we try again
667            while 1:
668                for result in self._getNextRecord():
669                    if result in (0,1):
670                        yield result
671                recordHeader, p = result
672
673                #If this is an empty application-data fragment, try again
674                if recordHeader.type == ContentType.application_data:
675                    if p.index == len(p.bytes):
676                        continue
677
678                #If we received an unexpected record type...
679                if recordHeader.type not in expectedType:
680
681                    #If we received an alert...
682                    if recordHeader.type == ContentType.alert:
683                        alert = Alert().parse(p)
684
685                        #We either received a fatal error, a warning, or a
686                        #close_notify.  In any case, we're going to close the
687                        #connection.  In the latter two cases we respond with
688                        #a close_notify, but ignore any socket errors, since
689                        #the other side might have already closed the socket.
690                        if alert.level == AlertLevel.warning or \
691                           alert.description == AlertDescription.close_notify:
692
693                            #If the sendMsg() call fails because the socket has
694                            #already been closed, we will be forgiving and not
695                            #report the error nor invalidate the "resumability"
696                            #of the session.
697                            try:
698                                alertMsg = Alert()
699                                alertMsg.create(AlertDescription.close_notify,
700                                                AlertLevel.warning)
701                                for result in self._sendMsg(alertMsg):
702                                    yield result
703                            except socket.error:
704                                pass
705
706                            if alert.description == \
707                                   AlertDescription.close_notify:
708                                self._shutdown(True)
709                            elif alert.level == AlertLevel.warning:
710                                self._shutdown(False)
711
712                        else: #Fatal alert:
713                            self._shutdown(False)
714
715                        #Raise the alert as an exception
716                        raise TLSRemoteAlert(alert)
717
718                    #If we received a renegotiation attempt...
719                    if recordHeader.type == ContentType.handshake:
720                        subType = p.get(1)
721                        reneg = False
722                        if self._client:
723                            if subType == HandshakeType.hello_request:
724                                reneg = True
725                        else:
726                            if subType == HandshakeType.client_hello:
727                                reneg = True
728                        #Send no_renegotiation, then try again
729                        if reneg:
730                            alertMsg = Alert()
731                            alertMsg.create(AlertDescription.no_renegotiation,
732                                            AlertLevel.warning)
733                            for result in self._sendMsg(alertMsg):
734                                yield result
735                            continue
736
737                    #Otherwise: this is an unexpected record, but neither an
738                    #alert nor renegotiation
739                    for result in self._sendError(\
740                            AlertDescription.unexpected_message,
741                            "received type=%d" % recordHeader.type):
742                        yield result
743
744                break
745
746            #Parse based on content_type
747            if recordHeader.type == ContentType.change_cipher_spec:
748                yield ChangeCipherSpec().parse(p)
749            elif recordHeader.type == ContentType.alert:
750                yield Alert().parse(p)
751            elif recordHeader.type == ContentType.application_data:
752                yield ApplicationData().parse(p)
753            elif recordHeader.type == ContentType.handshake:
754                #Convert secondaryType to tuple, if it isn't already
755                if not isinstance(secondaryType, tuple):
756                    secondaryType = (secondaryType,)
757
758                #If it's a handshake message, check handshake header
759                if recordHeader.ssl2:
760                    subType = p.get(1)
761                    if subType != HandshakeType.client_hello:
762                        for result in self._sendError(\
763                                AlertDescription.unexpected_message,
764                                "Can only handle SSLv2 ClientHello messages"):
765                            yield result
766                    if HandshakeType.client_hello not in secondaryType:
767                        for result in self._sendError(\
768                                AlertDescription.unexpected_message):
769                            yield result
770                    subType = HandshakeType.client_hello
771                else:
772                    subType = p.get(1)
773                    if subType not in secondaryType:
774                        for result in self._sendError(\
775                                AlertDescription.unexpected_message,
776                                "Expecting %s, got %s" % (str(secondaryType), subType)):
777                            yield result
778
779                #Update handshake hashes
780                self._handshake_md5.update(compat26Str(p.bytes))
781                self._handshake_sha.update(compat26Str(p.bytes))
782
783                #Parse based on handshake type
784                if subType == HandshakeType.client_hello:
785                    yield ClientHello(recordHeader.ssl2).parse(p)
786                elif subType == HandshakeType.server_hello:
787                    yield ServerHello().parse(p)
788                elif subType == HandshakeType.certificate:
789                    yield Certificate(constructorType).parse(p)
790                elif subType == HandshakeType.certificate_request:
791                    yield CertificateRequest().parse(p)
792                elif subType == HandshakeType.certificate_verify:
793                    yield CertificateVerify().parse(p)
794                elif subType == HandshakeType.server_key_exchange:
795                    yield ServerKeyExchange(constructorType).parse(p)
796                elif subType == HandshakeType.server_hello_done:
797                    yield ServerHelloDone().parse(p)
798                elif subType == HandshakeType.client_key_exchange:
799                    yield ClientKeyExchange(constructorType, \
800                                            self.version).parse(p)
801                elif subType == HandshakeType.finished:
802                    yield Finished(self.version).parse(p)
803                elif subType == HandshakeType.next_protocol:
804                    yield NextProtocol().parse(p)
805                elif subType == HandshakeType.encrypted_extensions:
806                    yield EncryptedExtensions().parse(p)
807                else:
808                    raise AssertionError()
809
810        #If an exception was raised by a Parser or Message instance:
811        except SyntaxError as e:
812            for result in self._sendError(AlertDescription.decode_error,
813                                         formatExceptionTrace(e)):
814                yield result
815
816
817    #Returns next record or next handshake message
818    def _getNextRecord(self):
819
820        #If there's a handshake message waiting, return it
821        if self._handshakeBuffer:
822            recordHeader, b = self._handshakeBuffer[0]
823            self._handshakeBuffer = self._handshakeBuffer[1:]
824            yield (recordHeader, Parser(b))
825            return
826
827        #Otherwise...
828        #Read the next record header
829        b = bytearray(0)
830        recordHeaderLength = 1
831        ssl2 = False
832        while 1:
833            try:
834                s = self.sock.recv(recordHeaderLength-len(b))
835            except socket.error as why:
836                if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
837                    yield 0
838                    continue
839                else:
840                    raise
841
842            #If the connection was abruptly closed, raise an error
843            if len(s)==0:
844                raise TLSAbruptCloseError()
845
846            b += bytearray(s)
847            if len(b)==1:
848                if b[0] in ContentType.all:
849                    ssl2 = False
850                    recordHeaderLength = 5
851                elif b[0] == 128:
852                    ssl2 = True
853                    recordHeaderLength = 2
854                else:
855                    raise SyntaxError()
856            if len(b) == recordHeaderLength:
857                break
858
859        #Parse the record header
860        if ssl2:
861            r = RecordHeader2().parse(Parser(b))
862        else:
863            r = RecordHeader3().parse(Parser(b))
864
865        #Check the record header fields
866        if r.length > 18432:
867            for result in self._sendError(AlertDescription.record_overflow):
868                yield result
869
870        #Read the record contents
871        b = bytearray(0)
872        while 1:
873            try:
874                s = self.sock.recv(r.length - len(b))
875            except socket.error as why:
876                if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
877                    yield 0
878                    continue
879                else:
880                    raise
881
882            #If the connection is closed, raise a socket error
883            if len(s)==0:
884                    raise TLSAbruptCloseError()
885
886            b += bytearray(s)
887            if len(b) == r.length:
888                break
889
890        #Check the record header fields (2)
891        #We do this after reading the contents from the socket, so that
892        #if there's an error, we at least don't leave extra bytes in the
893        #socket..
894        #
895        # THIS CHECK HAS NO SECURITY RELEVANCE (?), BUT COULD HURT INTEROP.
896        # SO WE LEAVE IT OUT FOR NOW.
897        #
898        #if self._versionCheck and r.version != self.version:
899        #    for result in self._sendError(AlertDescription.protocol_version,
900        #            "Version in header field: %s, should be %s" % (str(r.version),
901        #                                                       str(self.version))):
902        #        yield result
903
904        #Decrypt the record
905        for result in self._decryptRecord(r.type, b):
906            if result in (0,1): yield result
907            else: break
908        b = result
909        p = Parser(b)
910
911        #If it doesn't contain handshake messages, we can just return it
912        if r.type != ContentType.handshake:
913            yield (r, p)
914        #If it's an SSLv2 ClientHello, we can return it as well
915        elif r.ssl2:
916            yield (r, p)
917        else:
918            #Otherwise, we loop through and add the handshake messages to the
919            #handshake buffer
920            while 1:
921                if p.index == len(b): #If we're at the end
922                    if not self._handshakeBuffer:
923                        for result in self._sendError(\
924                                AlertDescription.decode_error, \
925                                "Received empty handshake record"):
926                            yield result
927                    break
928                #There needs to be at least 4 bytes to get a header
929                if p.index+4 > len(b):
930                    for result in self._sendError(\
931                            AlertDescription.decode_error,
932                            "A record has a partial handshake message (1)"):
933                        yield result
934                p.get(1) # skip handshake type
935                msgLength = p.get(3)
936                if p.index+msgLength > len(b):
937                    for result in self._sendError(\
938                            AlertDescription.decode_error,
939                            "A record has a partial handshake message (2)"):
940                        yield result
941
942                handshakePair = (r, b[p.index-4 : p.index+msgLength])
943                self._handshakeBuffer.append(handshakePair)
944                p.index += msgLength
945
946            #We've moved at least one handshake message into the
947            #handshakeBuffer, return the first one
948            recordHeader, b = self._handshakeBuffer[0]
949            self._handshakeBuffer = self._handshakeBuffer[1:]
950            yield (recordHeader, Parser(b))
951
952
953    def _decryptRecord(self, recordType, b):
954        if self._readState.encContext:
955
956            #Decrypt if it's a block cipher
957            if self._readState.encContext.isBlockCipher:
958                blockLength = self._readState.encContext.block_size
959                if len(b) % blockLength != 0:
960                    for result in self._sendError(\
961                            AlertDescription.decryption_failed,
962                            "Encrypted data not a multiple of blocksize"):
963                        yield result
964                b = self._readState.encContext.decrypt(b)
965                if self.version == (3,2): #For TLS 1.1, remove explicit IV
966                    b = b[self._readState.encContext.block_size : ]
967
968                #Check padding
969                paddingGood = True
970                paddingLength = b[-1]
971                if (paddingLength+1) > len(b):
972                    paddingGood=False
973                    totalPaddingLength = 0
974                else:
975                    if self.version == (3,0):
976                        totalPaddingLength = paddingLength+1
977                    elif self.version in ((3,1), (3,2)):
978                        totalPaddingLength = paddingLength+1
979                        paddingBytes = b[-totalPaddingLength:-1]
980                        for byte in paddingBytes:
981                            if byte != paddingLength:
982                                paddingGood = False
983                                totalPaddingLength = 0
984                    else:
985                        raise AssertionError()
986
987            #Decrypt if it's a stream cipher
988            else:
989                paddingGood = True
990                b = self._readState.encContext.decrypt(b)
991                totalPaddingLength = 0
992
993            #Check MAC
994            macGood = True
995            macLength = self._readState.macContext.digest_size
996            endLength = macLength + totalPaddingLength
997            if endLength > len(b):
998                macGood = False
999            else:
1000                #Read MAC
1001                startIndex = len(b) - endLength
1002                endIndex = startIndex + macLength
1003                checkBytes = b[startIndex : endIndex]
1004
1005                #Calculate MAC
1006                seqnumBytes = self._readState.getSeqNumBytes()
1007                b = b[:-endLength]
1008                mac = self._readState.macContext.copy()
1009                mac.update(compatHMAC(seqnumBytes))
1010                mac.update(compatHMAC(bytearray([recordType])))
1011                if self.version == (3,0):
1012                    mac.update( compatHMAC(bytearray( [len(b)//256] ) ))
1013                    mac.update( compatHMAC(bytearray( [len(b)%256] ) ))
1014                elif self.version in ((3,1), (3,2)):
1015                    mac.update(compatHMAC(bytearray( [self.version[0]] ) ))
1016                    mac.update(compatHMAC(bytearray( [self.version[1]] ) ))
1017                    mac.update(compatHMAC(bytearray( [len(b)//256] ) ))
1018                    mac.update(compatHMAC(bytearray( [len(b)%256] ) ))
1019                else:
1020                    raise AssertionError()
1021                mac.update(compatHMAC(b))
1022                macBytes = bytearray(mac.digest())
1023
1024                #Compare MACs
1025                if macBytes != checkBytes:
1026                    macGood = False
1027
1028            if not (paddingGood and macGood):
1029                for result in self._sendError(AlertDescription.bad_record_mac,
1030                                          "MAC failure (or padding failure)"):
1031                    yield result
1032
1033        yield b
1034
1035    def _handshakeStart(self, client):
1036        if not self.closed:
1037            raise ValueError("Renegotiation disallowed for security reasons")
1038        self._client = client
1039        self._handshake_md5 = hashlib.md5()
1040        self._handshake_sha = hashlib.sha1()
1041        self._handshakeBuffer = []
1042        self.allegedSrpUsername = None
1043        self._refCount = 1
1044
1045    def _handshakeDone(self, resumed):
1046        self.resumed = resumed
1047        self.closed = False
1048
1049    def _calcPendingStates(self, cipherSuite, masterSecret,
1050            clientRandom, serverRandom, implementations):
1051        if cipherSuite in CipherSuite.aes128Suites:
1052            keyLength = 16
1053            ivLength = 16
1054            createCipherFunc = createAES
1055        elif cipherSuite in CipherSuite.aes256Suites:
1056            keyLength = 32
1057            ivLength = 16
1058            createCipherFunc = createAES
1059        elif cipherSuite in CipherSuite.rc4Suites:
1060            keyLength = 16
1061            ivLength = 0
1062            createCipherFunc = createRC4
1063        elif cipherSuite in CipherSuite.tripleDESSuites:
1064            keyLength = 24
1065            ivLength = 8
1066            createCipherFunc = createTripleDES
1067        else:
1068            raise AssertionError()
1069
1070        if cipherSuite in CipherSuite.shaSuites:
1071            macLength = 20
1072            digestmod = hashlib.sha1
1073        elif cipherSuite in CipherSuite.md5Suites:
1074            macLength = 16
1075            digestmod = hashlib.md5
1076
1077        if self.version == (3,0):
1078            createMACFunc = createMAC_SSL
1079        elif self.version in ((3,1), (3,2)):
1080            createMACFunc = createHMAC
1081
1082        outputLength = (macLength*2) + (keyLength*2) + (ivLength*2)
1083
1084        #Calculate Keying Material from Master Secret
1085        if self.version == (3,0):
1086            keyBlock = PRF_SSL(masterSecret,
1087                               serverRandom + clientRandom,
1088                               outputLength)
1089        elif self.version in ((3,1), (3,2)):
1090            keyBlock = PRF(masterSecret,
1091                           b"key expansion",
1092                           serverRandom + clientRandom,
1093                           outputLength)
1094        else:
1095            raise AssertionError()
1096
1097        #Slice up Keying Material
1098        clientPendingState = _ConnectionState()
1099        serverPendingState = _ConnectionState()
1100        p = Parser(keyBlock)
1101        clientMACBlock = p.getFixBytes(macLength)
1102        serverMACBlock = p.getFixBytes(macLength)
1103        clientKeyBlock = p.getFixBytes(keyLength)
1104        serverKeyBlock = p.getFixBytes(keyLength)
1105        clientIVBlock  = p.getFixBytes(ivLength)
1106        serverIVBlock  = p.getFixBytes(ivLength)
1107        clientPendingState.macContext = createMACFunc(
1108            compatHMAC(clientMACBlock), digestmod=digestmod)
1109        serverPendingState.macContext = createMACFunc(
1110            compatHMAC(serverMACBlock), digestmod=digestmod)
1111        clientPendingState.encContext = createCipherFunc(clientKeyBlock,
1112                                                         clientIVBlock,
1113                                                         implementations)
1114        serverPendingState.encContext = createCipherFunc(serverKeyBlock,
1115                                                         serverIVBlock,
1116                                                         implementations)
1117
1118        #Assign new connection states to pending states
1119        if self._client:
1120            self._pendingWriteState = clientPendingState
1121            self._pendingReadState = serverPendingState
1122        else:
1123            self._pendingWriteState = serverPendingState
1124            self._pendingReadState = clientPendingState
1125
1126        if self.version == (3,2) and ivLength:
1127            #Choose fixedIVBlock for TLS 1.1 (this is encrypted with the CBC
1128            #residue to create the IV for each sent block)
1129            self.fixedIVBlock = getRandomBytes(ivLength)
1130
1131    def _changeWriteState(self):
1132        self._writeState = self._pendingWriteState
1133        self._pendingWriteState = _ConnectionState()
1134
1135    def _changeReadState(self):
1136        self._readState = self._pendingReadState
1137        self._pendingReadState = _ConnectionState()
1138
1139    #Used for Finished messages and CertificateVerify messages in SSL v3
1140    def _calcSSLHandshakeHash(self, masterSecret, label):
1141        imac_md5 = self._handshake_md5.copy()
1142        imac_sha = self._handshake_sha.copy()
1143
1144        imac_md5.update(compatHMAC(label + masterSecret + bytearray([0x36]*48)))
1145        imac_sha.update(compatHMAC(label + masterSecret + bytearray([0x36]*40)))
1146
1147        md5Bytes = MD5(masterSecret + bytearray([0x5c]*48) + \
1148                         bytearray(imac_md5.digest()))
1149        shaBytes = SHA1(masterSecret + bytearray([0x5c]*40) + \
1150                         bytearray(imac_sha.digest()))
1151
1152        return md5Bytes + shaBytes
1153
1154