1# Authors:
2#   Trevor Perrin
3#   Google (adapted by Sam Rushing) - NPN support
4#   Martin von Loewis - python 3 port
5#
6# See the LICENSE file for legal information regarding use of this file.
7
8"""Helper class for TLSConnection."""
9from __future__ import generators
10
11from .utils.compat import *
12from .utils.cryptomath import *
13from .utils.cipherfactory import createAES, createRC4, createTripleDES
14from .utils.codec import *
15from .errors import *
16from .messages import *
17from .mathtls import *
18from .constants import *
19from .utils.cryptomath import getRandomBytes
20
21import socket
22import struct
23import errno
24import traceback
25
26class _ConnectionState(object):
27    def __init__(self):
28        self.macContext = None
29        self.encContext = None
30        self.seqnum = 0
31
32    def getSeqNumBytes(self):
33        w = Writer()
34        w.add(self.seqnum, 8)
35        self.seqnum += 1
36        return w.bytes
37
38
39class TLSRecordLayer(object):
40    """
41    This class handles data transmission for a TLS connection.
42
43    Its only subclass is L{tlslite.TLSConnection.TLSConnection}.  We've
44    separated the code in this class from TLSConnection to make things
45    more readable.
46
47
48    @type sock: socket.socket
49    @ivar sock: The underlying socket object.
50
51    @type session: L{tlslite.Session.Session}
52    @ivar session: The session corresponding to this connection.
53
54    Due to TLS session resumption, multiple connections can correspond
55    to the same underlying session.
56
57    @type version: tuple
58    @ivar version: The TLS version being used for this connection.
59
    (3,0) means SSL 3.0, (3,1) means TLS 1.0, and (3,2) means TLS 1.1.
61
62    @type closed: bool
63    @ivar closed: If this connection is closed.
64
65    @type resumed: bool
66    @ivar resumed: If this connection is based on a resumed session.
67
68    @type allegedSrpUsername: str or None
69    @ivar allegedSrpUsername:  This is set to the SRP username
70    asserted by the client, whether the handshake succeeded or not.
71    If the handshake fails, this can be inspected to determine
72    if a guessing attack is in progress against a particular user
73    account.
74
75    @type closeSocket: bool
76    @ivar closeSocket: If the socket should be closed when the
77    connection is closed, defaults to True (writable).
78
    If you set this to True, TLS Lite will assume the responsibility of
    closing the socket when the TLS Connection is shutdown (either
    through an error or through the user calling close()).  The default
    is True.
83
84    @type ignoreAbruptClose: bool
    @ivar ignoreAbruptClose: If an abrupt close of the socket should be
    ignored instead of raising an error (writable).
87
88    If you set this to True, TLS Lite will not raise a
89    L{tlslite.errors.TLSAbruptCloseError} exception if the underlying
90    socket is unexpectedly closed.  Such an unexpected closure could be
91    caused by an attacker.  However, it also occurs with some incorrect
92    TLS implementations.
93
94    You should set this to True only if you're not worried about an
95    attacker truncating the connection, and only if necessary to avoid
96    spurious errors.  The default is False.
97
98    @sort: __init__, read, readAsync, write, writeAsync, close, closeAsync,
99    getCipherImplementation, getCipherName
100    """
101
102    def __init__(self, sock):
103        self.sock = sock
104
105        #My session object (Session instance; read-only)
106        self.session = None
107
108        #Am I a client or server?
109        self._client = None
110
111        #Buffers for processing messages
112        self._handshakeBuffer = []
113        self.clearReadBuffer()
114        self.clearWriteBuffer()
115
116        #Handshake digests
117        self._handshake_md5 = hashlib.md5()
118        self._handshake_sha = hashlib.sha1()
119
120        #TLS Protocol Version
121        self.version = (0,0) #read-only
122        self._versionCheck = False #Once we choose a version, this is True
123
124        #Current and Pending connection states
125        self._writeState = _ConnectionState()
126        self._readState = _ConnectionState()
127        self._pendingWriteState = _ConnectionState()
128        self._pendingReadState = _ConnectionState()
129
130        #Is the connection open?
131        self.closed = True #read-only
132        self._refCount = 0 #Used to trigger closure
133
134        #Is this a resumed session?
135        self.resumed = False #read-only
136
137        #What username did the client claim in his handshake?
138        self.allegedSrpUsername = None
139
140        #On a call to close(), do we close the socket? (writeable)
141        self.closeSocket = True
142
143        #If the socket is abruptly closed, do we ignore it
144        #and pretend the connection was shut down properly? (writeable)
145        self.ignoreAbruptClose = False
146
147        #Fault we will induce, for testing purposes
148        self.fault = None
149
150    def clearReadBuffer(self):
151        self._readBuffer = b''
152
153    def clearWriteBuffer(self):
154        self._send_writer = None
155
156
157    #*********************************************************
158    # Public Functions START
159    #*********************************************************
160
161    def read(self, max=None, min=1):
162        """Read some data from the TLS connection.
163
164        This function will block until at least 'min' bytes are
165        available (or the connection is closed).
166
167        If an exception is raised, the connection will have been
168        automatically closed.
169
170        @type max: int
171        @param max: The maximum number of bytes to return.
172
173        @type min: int
174        @param min: The minimum number of bytes to return
175
176        @rtype: str
177        @return: A string of no more than 'max' bytes, and no fewer
178        than 'min' (unless the connection has been closed, in which
179        case fewer than 'min' bytes may be returned).
180
181        @raise socket.error: If a socket error occurs.
182        @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
183        without a preceding alert.
184        @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
185        """
186        for result in self.readAsync(max, min):
187            pass
188        return result
189
190    def readAsync(self, max=None, min=1):
191        """Start a read operation on the TLS connection.
192
193        This function returns a generator which behaves similarly to
194        read().  Successive invocations of the generator will return 0
195        if it is waiting to read from the socket, 1 if it is waiting
196        to write to the socket, or a string if the read operation has
197        completed.
198
199        @rtype: iterable
200        @return: A generator; see above for details.
201        """
202        try:
203            while len(self._readBuffer)<min and not self.closed:
204                try:
205                    for result in self._getMsg(ContentType.application_data):
206                        if result in (0,1):
207                            yield result
208                    applicationData = result
209                    self._readBuffer += applicationData.write()
210                except TLSRemoteAlert as alert:
211                    if alert.description != AlertDescription.close_notify:
212                        raise
213                except TLSAbruptCloseError:
214                    if not self.ignoreAbruptClose:
215                        raise
216                    else:
217                        self._shutdown(True)
218
219            if max == None:
220                max = len(self._readBuffer)
221
222            returnBytes = self._readBuffer[:max]
223            self._readBuffer = self._readBuffer[max:]
224            yield bytes(returnBytes)
225        except GeneratorExit:
226            raise
227        except:
228            self._shutdown(False)
229            raise
230
231    def unread(self, b):
232        """Add bytes to the front of the socket read buffer for future
233        reading. Be careful using this in the context of select(...): if you
234        unread the last data from a socket, that won't wake up selected waiters,
235        and those waiters may hang forever.
236        """
237        self._readBuffer = b + self._readBuffer
238
239    def write(self, s):
240        """Write some data to the TLS connection.
241
242        This function will block until all the data has been sent.
243
244        If an exception is raised, the connection will have been
245        automatically closed.
246
247        @type s: str
248        @param s: The data to transmit to the other party.
249
250        @raise socket.error: If a socket error occurs.
251        """
252        for result in self.writeAsync(s):
253            pass
254
255    def writeAsync(self, s):
256        """Start a write operation on the TLS connection.
257
258        This function returns a generator which behaves similarly to
259        write().  Successive invocations of the generator will return
260        1 if it is waiting to write to the socket, or will raise
261        StopIteration if the write operation has completed.
262
263        @rtype: iterable
264        @return: A generator; see above for details.
265        """
266        try:
267            if self.closed:
268                raise TLSClosedConnectionError("attempt to write to closed connection")
269
270            index = 0
271            blockSize = 16384
272            randomizeFirstBlock = True
273            while 1:
274                startIndex = index * blockSize
275                endIndex = startIndex + blockSize
276                if startIndex >= len(s):
277                    break
278                if endIndex > len(s):
279                    endIndex = len(s)
280                block = bytearray(s[startIndex : endIndex])
281                applicationData = ApplicationData().create(block)
282                for result in self._sendMsg(applicationData, \
283                                            randomizeFirstBlock):
284                    yield result
285                randomizeFirstBlock = False #only on 1st message
286                index += 1
287        except GeneratorExit:
288            raise
289        except Exception:
290            # Don't invalidate the session on write failure if abrupt closes are
291            # okay.
292            self._shutdown(self.ignoreAbruptClose)
293            raise
294
295    def close(self):
296        """Close the TLS connection.
297
298        This function will block until it has exchanged close_notify
299        alerts with the other party.  After doing so, it will shut down the
300        TLS connection.  Further attempts to read through this connection
301        will return "".  Further attempts to write through this connection
302        will raise ValueError.
303
304        If makefile() has been called on this connection, the connection
305        will be not be closed until the connection object and all file
306        objects have been closed.
307
308        Even if an exception is raised, the connection will have been
309        closed.
310
311        @raise socket.error: If a socket error occurs.
312        @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
313        without a preceding alert.
314        @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
315        """
316        if not self.closed:
317            for result in self._decrefAsync():
318                pass
319
320    # Python 3 callback
321    _decref_socketios = close
322
323    def closeAsync(self):
324        """Start a close operation on the TLS connection.
325
326        This function returns a generator which behaves similarly to
327        close().  Successive invocations of the generator will return 0
328        if it is waiting to read from the socket, 1 if it is waiting
329        to write to the socket, or will raise StopIteration if the
330        close operation has completed.
331
332        @rtype: iterable
333        @return: A generator; see above for details.
334        """
335        if not self.closed:
336            for result in self._decrefAsync():
337                yield result
338
    def _decrefAsync(self):
        """Drop one reference to the connection and close it on the last one.

        This generator is behind close() and closeAsync().  It decrements
        _refCount, which makefile() increments.  NOTE(review): __init__
        starts the count at 0, so presumably the handshake code (not
        visible here) takes the initial reference; TODO confirm.

        When the count reaches zero on an open connection, it sends a
        close_notify warning alert.  It then either shuts down at once
        (closeSocket is True) or waits for the peer's alert, reading and
        dropping any application data that arrives first.  It passes
        through the 0/1 wait values yielded by _sendMsg/_getMsg.

        A socket error or TLSAbruptCloseError during the exchange is
        treated as a clean (resumable) shutdown.  Any other exception
        shuts the connection down non-resumably and is re-raised.

        @raise tlslite.errors.TLSRemoteAlert: If the peer answers with an
        alert other than close_notify (only when closeSocket is False).
        """
        # One reference fewer; only the last one triggers the close exchange.
        self._refCount -= 1
        if self._refCount == 0 and not self.closed:
            try:
                # Send our own close_notify (warning level).
                for result in self._sendMsg(Alert().create(\
                        AlertDescription.close_notify, AlertLevel.warning)):
                    yield result
                alert = None
                # By default close the socket, since it's been observed
                # that some other libraries will not respond to the
                # close_notify alert, thus leaving us hanging if we're
                # expecting it
                if self.closeSocket:
                    self._shutdown(True)
                else:
                    # Wait for the peer's alert; application-data records
                    # received before it are read and discarded.
                    while not alert:
                        for result in self._getMsg((ContentType.alert, \
                                                  ContentType.application_data)):
                            if result in (0,1):
                                yield result
                        if result.contentType == ContentType.alert:
                            alert = result
                    if alert.description == AlertDescription.close_notify:
                        self._shutdown(True)
                    else:
                        # Falls to the bare except below (assuming
                        # TLSRemoteAlert is not a socket.error subclass):
                        # shut down non-resumably, then re-raised.
                        raise TLSRemoteAlert(alert)
            except (socket.error, TLSAbruptCloseError):
                #If the other side closes the socket, that's okay
                self._shutdown(True)
            except GeneratorExit:
                raise
            except:
                self._shutdown(False)
                raise
373
374    def getVersionName(self):
375        """Get the name of this TLS version.
376
377        @rtype: str
378        @return: The name of the TLS version used with this connection.
379        Either None, 'SSL 3.0', 'TLS 1.0', or 'TLS 1.1'.
380        """
381        if self.version == (3,0):
382            return "SSL 3.0"
383        elif self.version == (3,1):
384            return "TLS 1.0"
385        elif self.version == (3,2):
386            return "TLS 1.1"
387        else:
388            return None
389
390    def getCipherName(self):
391        """Get the name of the cipher used with this connection.
392
393        @rtype: str
394        @return: The name of the cipher used with this connection.
395        Either 'aes128', 'aes256', 'rc4', or '3des'.
396        """
397        if not self._writeState.encContext:
398            return None
399        return self._writeState.encContext.name
400
401    def getCipherImplementation(self):
402        """Get the name of the cipher implementation used with
403        this connection.
404
405        @rtype: str
406        @return: The name of the cipher implementation used with
407        this connection.  Either 'python', 'openssl', or 'pycrypto'.
408        """
409        if not self._writeState.encContext:
410            return None
411        return self._writeState.encContext.implementation
412
413
414
415    #Emulate a socket, somewhat -
416    def send(self, s):
417        """Send data to the TLS connection (socket emulation).
418
419        @raise socket.error: If a socket error occurs.
420        """
421        self.write(s)
422        return len(s)
423
424    def sendall(self, s):
425        """Send data to the TLS connection (socket emulation).
426
427        @raise socket.error: If a socket error occurs.
428        """
429        self.write(s)
430
431    def recv(self, bufsize):
432        """Get some data from the TLS connection (socket emulation).
433
434        @raise socket.error: If a socket error occurs.
435        @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
436        without a preceding alert.
437        @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
438        """
439        return self.read(bufsize)
440
441    def recv_into(self, b):
442        # XXX doc string
443        data = self.read(len(b))
444        if not data:
445            return None
446        b[:len(data)] = data
447        return len(data)
448
449    def makefile(self, mode='r', bufsize=-1):
450        """Create a file object for the TLS connection (socket emulation).
451
452        @rtype: L{socket._fileobject}
453        """
454        self._refCount += 1
455        # So, it is pretty fragile to be using Python internal objects
456        # like this, but it is probably the best/easiest way to provide
457        # matching behavior for socket emulation purposes.  The 'close'
458        # argument is nice, its apparently a recent addition to this
459        # class, so that when fileobject.close() gets called, it will
460        # close() us, causing the refcount to be decremented (decrefAsync).
461        #
462        # If this is the last close() on the outstanding fileobjects /
463        # TLSConnection, then the "actual" close alerts will be sent,
464        # socket closed, etc.
465        if sys.version_info < (3,):
466            return socket._fileobject(self, mode, bufsize, close=True)
467        else:
468            # XXX need to wrap this further if buffering is requested
469            return socket.SocketIO(self, mode)
470
471    def getsockname(self):
472        """Return the socket's own address (socket emulation)."""
473        return self.sock.getsockname()
474
475    def getpeername(self):
476        """Return the remote address to which the socket is connected
477        (socket emulation)."""
478        return self.sock.getpeername()
479
480    def settimeout(self, value):
481        """Set a timeout on blocking socket operations (socket emulation)."""
482        return self.sock.settimeout(value)
483
484    def gettimeout(self):
485        """Return the timeout associated with socket operations (socket
486        emulation)."""
487        return self.sock.gettimeout()
488
489    def setsockopt(self, level, optname, value):
490        """Set the value of the given socket option (socket emulation)."""
491        return self.sock.setsockopt(level, optname, value)
492
493    def shutdown(self, how):
494        """Shutdown the underlying socket."""
495        return self.sock.shutdown(how)
496
497    def fileno(self):
498        """Not implement in TLS Lite."""
499        raise NotImplementedError()
500
501
502     #*********************************************************
503     # Public Functions END
504     #*********************************************************
505
506    def _shutdown(self, resumable):
507        self._writeState = _ConnectionState()
508        self._readState = _ConnectionState()
509        self.version = (0,0)
510        self._versionCheck = False
511        self.closed = True
512        if self.closeSocket:
513            self.sock.close()
514
515        #Even if resumable is False, we'll never toggle this on
516        if not resumable and self.session:
517            self.session.resumable = False
518
519
520    def _sendError(self, alertDescription, errorStr=None):
521        alert = Alert().create(alertDescription, AlertLevel.fatal)
522        for result in self._sendMsg(alert):
523            yield result
524        self._shutdown(False)
525        raise TLSLocalAlert(alert, errorStr)
526
527    def _abruptClose(self, reset=False):
528        if reset:
529            #Set an SO_LINGER timeout of 0 to send a TCP RST.
530            self.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
531                            struct.pack('ii', 1, 0))
532        self._shutdown(False)
533
534    def _sendMsgs(self, msgs):
535        randomizeFirstBlock = True
536        for msg in msgs:
537            for result in self._sendMsg(msg, randomizeFirstBlock):
538                yield result
539            randomizeFirstBlock = True
540
541    def _sendMsg(self, msg, randomizeFirstBlock = True):
542        #Whenever we're connected and asked to send an app data message,
543        #we first send the first byte of the message.  This prevents
544        #an attacker from launching a chosen-plaintext attack based on
545        #knowing the next IV (a la BEAST).
546        if not self.closed and randomizeFirstBlock and self.version <= (3,1) \
547                and self._writeState.encContext \
548                and self._writeState.encContext.isBlockCipher \
549                and isinstance(msg, ApplicationData):
550            msgFirstByte = msg.splitFirstByte()
551            for result in self._sendMsg(msgFirstByte,
552                                       randomizeFirstBlock = False):
553                yield result
554
555        b = msg.write()
556
557        # If a 1-byte message was passed in, and we "split" the
558        # first(only) byte off above, we may have a 0-length msg:
559        if len(b) == 0:
560            return
561
562        contentType = msg.contentType
563
564        #Update handshake hashes
565        if contentType == ContentType.handshake:
566            self._handshake_md5.update(compat26Str(b))
567            self._handshake_sha.update(compat26Str(b))
568
569        #Calculate MAC
570        if self._writeState.macContext:
571            seqnumBytes = self._writeState.getSeqNumBytes()
572            mac = self._writeState.macContext.copy()
573            mac.update(compatHMAC(seqnumBytes))
574            mac.update(compatHMAC(bytearray([contentType])))
575            if self.version == (3,0):
576                mac.update( compatHMAC( bytearray([len(b)//256] )))
577                mac.update( compatHMAC( bytearray([len(b)%256] )))
578            elif self.version in ((3,1), (3,2)):
579                mac.update(compatHMAC( bytearray([self.version[0]] )))
580                mac.update(compatHMAC( bytearray([self.version[1]] )))
581                mac.update( compatHMAC( bytearray([len(b)//256] )))
582                mac.update( compatHMAC( bytearray([len(b)%256] )))
583            else:
584                raise AssertionError()
585            mac.update(compatHMAC(b))
586            macBytes = bytearray(mac.digest())
587            if self.fault == Fault.badMAC:
588                macBytes[0] = (macBytes[0]+1) % 256
589
590        #Encrypt for Block or Stream Cipher
591        if self._writeState.encContext:
592            #Add padding and encrypt (for Block Cipher):
593            if self._writeState.encContext.isBlockCipher:
594
595                #Add TLS 1.1 fixed block
596                if self.version == (3,2):
597                    b = self.fixedIVBlock + b
598
599                #Add padding: b = b + (macBytes + paddingBytes)
600                currentLength = len(b) + len(macBytes)
601                blockLength = self._writeState.encContext.block_size
602                paddingLength = blockLength - 1 - (currentLength % blockLength)
603
604                paddingBytes = bytearray([paddingLength] * (paddingLength+1))
605                if self.fault == Fault.badPadding:
606                    paddingBytes[0] = (paddingBytes[0]+1) % 256
607                endBytes = macBytes + paddingBytes
608                b += endBytes
609                #Encrypt
610                b = self._writeState.encContext.encrypt(b)
611
612            #Encrypt (for Stream Cipher)
613            else:
614                b += macBytes
615                b = self._writeState.encContext.encrypt(b)
616
617        #Add record header and send
618        r = RecordHeader3().create(self.version, contentType, len(b))
619        s = r.write() + b
620        while 1:
621            try:
622                bytesSent = self.sock.send(s) #Might raise socket.error
623            except socket.error as why:
624                if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
625                    yield 1
626                    continue
627                else:
628                    # The socket was unexpectedly closed.  The tricky part
629                    # is that there may be an alert sent by the other party
630                    # sitting in the read buffer.  So, if we get here after
631                    # handshaking, we will just raise the error and let the
632                    # caller read more data if it would like, thus stumbling
633                    # upon the error.
634                    #
635                    # However, if we get here DURING handshaking, we take
636                    # it upon ourselves to see if the next message is an
637                    # Alert.
638                    if contentType == ContentType.handshake:
639
640                        # See if there's an alert record
641                        # Could raise socket.error or TLSAbruptCloseError
642                        for result in self._getNextRecord():
643                            if result in (0,1):
644                                yield result
645
646                        # Closes the socket
647                        self._shutdown(False)
648
649                        # If we got an alert, raise it
650                        recordHeader, p = result
651                        if recordHeader.type == ContentType.alert:
652                            alert = Alert().parse(p)
653                            raise TLSRemoteAlert(alert)
654                    else:
655                        # If we got some other message who know what
656                        # the remote side is doing, just go ahead and
657                        # raise the socket.error
658                        raise
659            if bytesSent == len(s):
660                return
661            s = s[bytesSent:]
662            yield 1
663
664
665    def _getMsg(self, expectedType, secondaryType=None, constructorType=None):
666        try:
667            if not isinstance(expectedType, tuple):
668                expectedType = (expectedType,)
669
670            #Spin in a loop, until we've got a non-empty record of a type we
671            #expect.  The loop will be repeated if:
672            #  - we receive a renegotiation attempt; we send no_renegotiation,
673            #    then try again
674            #  - we receive an empty application-data fragment; we try again
675            while 1:
676                for result in self._getNextRecord():
677                    if result in (0,1):
678                        yield result
679                recordHeader, p = result
680
681                #If this is an empty application-data fragment, try again
682                if recordHeader.type == ContentType.application_data:
683                    if p.index == len(p.bytes):
684                        continue
685
686                #If we received an unexpected record type...
687                if recordHeader.type not in expectedType:
688
689                    #If we received an alert...
690                    if recordHeader.type == ContentType.alert:
691                        alert = Alert().parse(p)
692
693                        #We either received a fatal error, a warning, or a
694                        #close_notify.  In any case, we're going to close the
695                        #connection.  In the latter two cases we respond with
696                        #a close_notify, but ignore any socket errors, since
697                        #the other side might have already closed the socket.
698                        if alert.level == AlertLevel.warning or \
699                           alert.description == AlertDescription.close_notify:
700
701                            #If the sendMsg() call fails because the socket has
702                            #already been closed, we will be forgiving and not
703                            #report the error nor invalidate the "resumability"
704                            #of the session.
705                            try:
706                                alertMsg = Alert()
707                                alertMsg.create(AlertDescription.close_notify,
708                                                AlertLevel.warning)
709                                for result in self._sendMsg(alertMsg):
710                                    yield result
711                            except socket.error:
712                                pass
713
714                            if alert.description == \
715                                   AlertDescription.close_notify:
716                                self._shutdown(True)
717                            elif alert.level == AlertLevel.warning:
718                                self._shutdown(False)
719
720                        else: #Fatal alert:
721                            self._shutdown(False)
722
723                        #Raise the alert as an exception
724                        raise TLSRemoteAlert(alert)
725
726                    #If we received a renegotiation attempt...
727                    if recordHeader.type == ContentType.handshake:
728                        subType = p.get(1)
729                        reneg = False
730                        if self._client:
731                            if subType == HandshakeType.hello_request:
732                                reneg = True
733                        else:
734                            if subType == HandshakeType.client_hello:
735                                reneg = True
736                        #Send no_renegotiation, then try again
737                        if reneg:
738                            alertMsg = Alert()
739                            alertMsg.create(AlertDescription.no_renegotiation,
740                                            AlertLevel.warning)
741                            for result in self._sendMsg(alertMsg):
742                                yield result
743                            continue
744
745                    #Otherwise: this is an unexpected record, but neither an
746                    #alert nor renegotiation
747                    for result in self._sendError(\
748                            AlertDescription.unexpected_message,
749                            "received type=%d" % recordHeader.type):
750                        yield result
751
752                break
753
754            #Parse based on content_type
755            if recordHeader.type == ContentType.change_cipher_spec:
756                yield ChangeCipherSpec().parse(p)
757            elif recordHeader.type == ContentType.alert:
758                yield Alert().parse(p)
759            elif recordHeader.type == ContentType.application_data:
760                yield ApplicationData().parse(p)
761            elif recordHeader.type == ContentType.handshake:
762                #Convert secondaryType to tuple, if it isn't already
763                if not isinstance(secondaryType, tuple):
764                    secondaryType = (secondaryType,)
765
766                #If it's a handshake message, check handshake header
767                if recordHeader.ssl2:
768                    subType = p.get(1)
769                    if subType != HandshakeType.client_hello:
770                        for result in self._sendError(\
771                                AlertDescription.unexpected_message,
772                                "Can only handle SSLv2 ClientHello messages"):
773                            yield result
774                    if HandshakeType.client_hello not in secondaryType:
775                        for result in self._sendError(\
776                                AlertDescription.unexpected_message):
777                            yield result
778                    subType = HandshakeType.client_hello
779                else:
780                    subType = p.get(1)
781                    if subType not in secondaryType:
782                        for result in self._sendError(\
783                                AlertDescription.unexpected_message,
784                                "Expecting %s, got %s" % (str(secondaryType), subType)):
785                            yield result
786
787                #Update handshake hashes
788                self._handshake_md5.update(compat26Str(p.bytes))
789                self._handshake_sha.update(compat26Str(p.bytes))
790
791                #Parse based on handshake type
792                if subType == HandshakeType.client_hello:
793                    yield ClientHello(recordHeader.ssl2).parse(p)
794                elif subType == HandshakeType.server_hello:
795                    yield ServerHello().parse(p)
796                elif subType == HandshakeType.certificate:
797                    yield Certificate(constructorType).parse(p)
798                elif subType == HandshakeType.certificate_request:
799                    yield CertificateRequest().parse(p)
800                elif subType == HandshakeType.certificate_verify:
801                    yield CertificateVerify().parse(p)
802                elif subType == HandshakeType.server_key_exchange:
803                    yield ServerKeyExchange(constructorType).parse(p)
804                elif subType == HandshakeType.server_hello_done:
805                    yield ServerHelloDone().parse(p)
806                elif subType == HandshakeType.client_key_exchange:
807                    yield ClientKeyExchange(constructorType, \
808                                            self.version).parse(p)
809                elif subType == HandshakeType.finished:
810                    yield Finished(self.version).parse(p)
811                elif subType == HandshakeType.next_protocol:
812                    yield NextProtocol().parse(p)
813                elif subType == HandshakeType.encrypted_extensions:
814                    yield EncryptedExtensions().parse(p)
815                else:
816                    raise AssertionError()
817
818        #If an exception was raised by a Parser or Message instance:
819        except SyntaxError as e:
820            for result in self._sendError(AlertDescription.decode_error,
821                                         formatExceptionTrace(e)):
822                yield result
823
824
825    #Returns next record or next handshake message
826    def _getNextRecord(self):
827
828        #If there's a handshake message waiting, return it
829        if self._handshakeBuffer:
830            recordHeader, b = self._handshakeBuffer[0]
831            self._handshakeBuffer = self._handshakeBuffer[1:]
832            yield (recordHeader, Parser(b))
833            return
834
835        #Otherwise...
836        #Read the next record header
837        b = bytearray(0)
838        recordHeaderLength = 1
839        ssl2 = False
840        while 1:
841            try:
842                s = self.sock.recv(recordHeaderLength-len(b))
843            except socket.error as why:
844                if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
845                    yield 0
846                    continue
847                else:
848                    raise
849
850            #If the connection was abruptly closed, raise an error
851            if len(s)==0:
852                raise TLSAbruptCloseError()
853
854            b += bytearray(s)
855            if len(b)==1:
856                if b[0] in ContentType.all:
857                    ssl2 = False
858                    recordHeaderLength = 5
859                elif b[0] == 128:
860                    ssl2 = True
861                    recordHeaderLength = 2
862                else:
863                    raise SyntaxError()
864            if len(b) == recordHeaderLength:
865                break
866
867        #Parse the record header
868        if ssl2:
869            r = RecordHeader2().parse(Parser(b))
870        else:
871            r = RecordHeader3().parse(Parser(b))
872
873        #Check the record header fields
874        if r.length > 18432:
875            for result in self._sendError(AlertDescription.record_overflow):
876                yield result
877
878        #Read the record contents
879        b = bytearray(0)
880        while 1:
881            try:
882                s = self.sock.recv(r.length - len(b))
883            except socket.error as why:
884                if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
885                    yield 0
886                    continue
887                else:
888                    raise
889
890            #If the connection is closed, raise a socket error
891            if len(s)==0:
892                    raise TLSAbruptCloseError()
893
894            b += bytearray(s)
895            if len(b) == r.length:
896                break
897
898        #Check the record header fields (2)
899        #We do this after reading the contents from the socket, so that
900        #if there's an error, we at least don't leave extra bytes in the
901        #socket..
902        #
903        # THIS CHECK HAS NO SECURITY RELEVANCE (?), BUT COULD HURT INTEROP.
904        # SO WE LEAVE IT OUT FOR NOW.
905        #
906        #if self._versionCheck and r.version != self.version:
907        #    for result in self._sendError(AlertDescription.protocol_version,
908        #            "Version in header field: %s, should be %s" % (str(r.version),
909        #                                                       str(self.version))):
910        #        yield result
911
912        #Decrypt the record
913        for result in self._decryptRecord(r.type, b):
914            if result in (0,1): yield result
915            else: break
916        b = result
917        p = Parser(b)
918
919        #If it doesn't contain handshake messages, we can just return it
920        if r.type != ContentType.handshake:
921            yield (r, p)
922        #If it's an SSLv2 ClientHello, we can return it as well
923        elif r.ssl2:
924            yield (r, p)
925        else:
926            #Otherwise, we loop through and add the handshake messages to the
927            #handshake buffer
928            while 1:
929                if p.index == len(b): #If we're at the end
930                    if not self._handshakeBuffer:
931                        for result in self._sendError(\
932                                AlertDescription.decode_error, \
933                                "Received empty handshake record"):
934                            yield result
935                    break
936                #There needs to be at least 4 bytes to get a header
937                if p.index+4 > len(b):
938                    for result in self._sendError(\
939                            AlertDescription.decode_error,
940                            "A record has a partial handshake message (1)"):
941                        yield result
942                p.get(1) # skip handshake type
943                msgLength = p.get(3)
944                if p.index+msgLength > len(b):
945                    for result in self._sendError(\
946                            AlertDescription.decode_error,
947                            "A record has a partial handshake message (2)"):
948                        yield result
949
950                handshakePair = (r, b[p.index-4 : p.index+msgLength])
951                self._handshakeBuffer.append(handshakePair)
952                p.index += msgLength
953
954            #We've moved at least one handshake message into the
955            #handshakeBuffer, return the first one
956            recordHeader, b = self._handshakeBuffer[0]
957            self._handshakeBuffer = self._handshakeBuffer[1:]
958            yield (recordHeader, Parser(b))
959
960
961    def _decryptRecord(self, recordType, b):
962        if self._readState.encContext:
963
964            #Decrypt if it's a block cipher
965            if self._readState.encContext.isBlockCipher:
966                blockLength = self._readState.encContext.block_size
967                if len(b) % blockLength != 0:
968                    for result in self._sendError(\
969                            AlertDescription.decryption_failed,
970                            "Encrypted data not a multiple of blocksize"):
971                        yield result
972                b = self._readState.encContext.decrypt(b)
973                if self.version == (3,2): #For TLS 1.1, remove explicit IV
974                    b = b[self._readState.encContext.block_size : ]
975
976                #Check padding
977                paddingGood = True
978                paddingLength = b[-1]
979                if (paddingLength+1) > len(b):
980                    paddingGood=False
981                    totalPaddingLength = 0
982                else:
983                    if self.version == (3,0):
984                        totalPaddingLength = paddingLength+1
985                    elif self.version in ((3,1), (3,2)):
986                        totalPaddingLength = paddingLength+1
987                        paddingBytes = b[-totalPaddingLength:-1]
988                        for byte in paddingBytes:
989                            if byte != paddingLength:
990                                paddingGood = False
991                                totalPaddingLength = 0
992                    else:
993                        raise AssertionError()
994
995            #Decrypt if it's a stream cipher
996            else:
997                paddingGood = True
998                b = self._readState.encContext.decrypt(b)
999                totalPaddingLength = 0
1000
1001            #Check MAC
1002            macGood = True
1003            macLength = self._readState.macContext.digest_size
1004            endLength = macLength + totalPaddingLength
1005            if endLength > len(b):
1006                macGood = False
1007            else:
1008                #Read MAC
1009                startIndex = len(b) - endLength
1010                endIndex = startIndex + macLength
1011                checkBytes = b[startIndex : endIndex]
1012
1013                #Calculate MAC
1014                seqnumBytes = self._readState.getSeqNumBytes()
1015                b = b[:-endLength]
1016                mac = self._readState.macContext.copy()
1017                mac.update(compatHMAC(seqnumBytes))
1018                mac.update(compatHMAC(bytearray([recordType])))
1019                if self.version == (3,0):
1020                    mac.update( compatHMAC(bytearray( [len(b)//256] ) ))
1021                    mac.update( compatHMAC(bytearray( [len(b)%256] ) ))
1022                elif self.version in ((3,1), (3,2)):
1023                    mac.update(compatHMAC(bytearray( [self.version[0]] ) ))
1024                    mac.update(compatHMAC(bytearray( [self.version[1]] ) ))
1025                    mac.update(compatHMAC(bytearray( [len(b)//256] ) ))
1026                    mac.update(compatHMAC(bytearray( [len(b)%256] ) ))
1027                else:
1028                    raise AssertionError()
1029                mac.update(compatHMAC(b))
1030                macBytes = bytearray(mac.digest())
1031
1032                #Compare MACs
1033                if macBytes != checkBytes:
1034                    macGood = False
1035
1036            if not (paddingGood and macGood):
1037                for result in self._sendError(AlertDescription.bad_record_mac,
1038                                          "MAC failure (or padding failure)"):
1039                    yield result
1040
1041        yield b
1042
1043    def _handshakeStart(self, client):
1044        if not self.closed:
1045            raise ValueError("Renegotiation disallowed for security reasons")
1046        self._client = client
1047        self._handshake_md5 = hashlib.md5()
1048        self._handshake_sha = hashlib.sha1()
1049        self._handshakeBuffer = []
1050        self.allegedSrpUsername = None
1051        self._refCount = 1
1052
1053    def _handshakeDone(self, resumed):
1054        self.resumed = resumed
1055        self.closed = False
1056
1057    def _calcPendingStates(self, cipherSuite, masterSecret,
1058            clientRandom, serverRandom, implementations):
1059        if cipherSuite in CipherSuite.aes128Suites:
1060            keyLength = 16
1061            ivLength = 16
1062            createCipherFunc = createAES
1063        elif cipherSuite in CipherSuite.aes256Suites:
1064            keyLength = 32
1065            ivLength = 16
1066            createCipherFunc = createAES
1067        elif cipherSuite in CipherSuite.rc4Suites:
1068            keyLength = 16
1069            ivLength = 0
1070            createCipherFunc = createRC4
1071        elif cipherSuite in CipherSuite.tripleDESSuites:
1072            keyLength = 24
1073            ivLength = 8
1074            createCipherFunc = createTripleDES
1075        else:
1076            raise AssertionError()
1077
1078        if cipherSuite in CipherSuite.shaSuites:
1079            macLength = 20
1080            digestmod = hashlib.sha1
1081        elif cipherSuite in CipherSuite.md5Suites:
1082            macLength = 16
1083            digestmod = hashlib.md5
1084
1085        if self.version == (3,0):
1086            createMACFunc = createMAC_SSL
1087        elif self.version in ((3,1), (3,2)):
1088            createMACFunc = createHMAC
1089
1090        outputLength = (macLength*2) + (keyLength*2) + (ivLength*2)
1091
1092        #Calculate Keying Material from Master Secret
1093        if self.version == (3,0):
1094            keyBlock = PRF_SSL(masterSecret,
1095                               serverRandom + clientRandom,
1096                               outputLength)
1097        elif self.version in ((3,1), (3,2)):
1098            keyBlock = PRF(masterSecret,
1099                           b"key expansion",
1100                           serverRandom + clientRandom,
1101                           outputLength)
1102        else:
1103            raise AssertionError()
1104
1105        #Slice up Keying Material
1106        clientPendingState = _ConnectionState()
1107        serverPendingState = _ConnectionState()
1108        p = Parser(keyBlock)
1109        clientMACBlock = p.getFixBytes(macLength)
1110        serverMACBlock = p.getFixBytes(macLength)
1111        clientKeyBlock = p.getFixBytes(keyLength)
1112        serverKeyBlock = p.getFixBytes(keyLength)
1113        clientIVBlock  = p.getFixBytes(ivLength)
1114        serverIVBlock  = p.getFixBytes(ivLength)
1115        clientPendingState.macContext = createMACFunc(
1116            compatHMAC(clientMACBlock), digestmod=digestmod)
1117        serverPendingState.macContext = createMACFunc(
1118            compatHMAC(serverMACBlock), digestmod=digestmod)
1119        clientPendingState.encContext = createCipherFunc(clientKeyBlock,
1120                                                         clientIVBlock,
1121                                                         implementations)
1122        serverPendingState.encContext = createCipherFunc(serverKeyBlock,
1123                                                         serverIVBlock,
1124                                                         implementations)
1125
1126        #Assign new connection states to pending states
1127        if self._client:
1128            self._pendingWriteState = clientPendingState
1129            self._pendingReadState = serverPendingState
1130        else:
1131            self._pendingWriteState = serverPendingState
1132            self._pendingReadState = clientPendingState
1133
1134        if self.version == (3,2) and ivLength:
1135            #Choose fixedIVBlock for TLS 1.1 (this is encrypted with the CBC
1136            #residue to create the IV for each sent block)
1137            self.fixedIVBlock = getRandomBytes(ivLength)
1138
1139    def _changeWriteState(self):
1140        self._writeState = self._pendingWriteState
1141        self._pendingWriteState = _ConnectionState()
1142
1143    def _changeReadState(self):
1144        self._readState = self._pendingReadState
1145        self._pendingReadState = _ConnectionState()
1146
1147    #Used for Finished messages and CertificateVerify messages in SSL v3
1148    def _calcSSLHandshakeHash(self, masterSecret, label):
1149        imac_md5 = self._handshake_md5.copy()
1150        imac_sha = self._handshake_sha.copy()
1151
1152        imac_md5.update(compatHMAC(label + masterSecret + bytearray([0x36]*48)))
1153        imac_sha.update(compatHMAC(label + masterSecret + bytearray([0x36]*40)))
1154
1155        md5Bytes = MD5(masterSecret + bytearray([0x5c]*48) + \
1156                         bytearray(imac_md5.digest()))
1157        shaBytes = SHA1(masterSecret + bytearray([0x5c]*40) + \
1158                         bytearray(imac_sha.digest()))
1159
1160        return md5Bytes + shaBytes
1161
1162