1"""Helper class for TLSConnection."""
2from __future__ import generators
3
4from utils.compat import *
5from utils.cryptomath import *
6from utils.cipherfactory import createAES, createRC4, createTripleDES
7from utils.codec import *
8from errors import *
9from messages import *
10from mathtls import *
11from constants import *
12from utils.cryptomath import getRandomBytes
13from utils import hmac
14from FileObject import FileObject
15
16# The sha module is deprecated in Python 2.6
17try:
18    import sha
19except ImportError:
20    from hashlib import sha1 as sha
21
22# The md5 module is deprecated in Python 2.6
23try:
24    import md5
25except ImportError:
26    from hashlib import md5
27
28import socket
29import errno
30import traceback
31
class _ConnectionState:
    """Per-direction record-layer state: the MAC context, the cipher
    context and the running record sequence number."""

    def __init__(self):
        #None for either context means records are processed without a
        #MAC / without encryption.
        self.encContext = None
        self.macContext = None
        self.seqnum = 0

    def getSeqNumStr(self):
        """Return the current sequence number packed by Writer into an
        8-byte string, then advance the counter for the next record."""
        writer = Writer(8)
        writer.add(self.seqnum, 8)
        packed = bytesToString(writer.bytes)
        self.seqnum = self.seqnum + 1
        return packed
44
45
46class TLSRecordLayer:
47    """
48    This class handles data transmission for a TLS connection.
49
50    Its only subclass is L{tlslite.TLSConnection.TLSConnection}.  We've
51    separated the code in this class from TLSConnection to make things
52    more readable.
53
54
55    @type sock: socket.socket
56    @ivar sock: The underlying socket object.
57
58    @type session: L{tlslite.Session.Session}
59    @ivar session: The session corresponding to this connection.
60
61    Due to TLS session resumption, multiple connections can correspond
62    to the same underlying session.
63
64    @type version: tuple
65    @ivar version: The TLS version being used for this connection.
66
    (3,0) means SSL 3.0, (3,1) means TLS 1.0, and (3,2) means TLS 1.1.
68
69    @type closed: bool
70    @ivar closed: If this connection is closed.
71
72    @type resumed: bool
73    @ivar resumed: If this connection is based on a resumed session.
74
75    @type allegedSharedKeyUsername: str or None
76    @ivar allegedSharedKeyUsername:  This is set to the shared-key
77    username asserted by the client, whether the handshake succeeded or
78    not.  If the handshake fails, this can be inspected to
79    determine if a guessing attack is in progress against a particular
80    user account.
81
82    @type allegedSrpUsername: str or None
83    @ivar allegedSrpUsername:  This is set to the SRP username
84    asserted by the client, whether the handshake succeeded or not.
85    If the handshake fails, this can be inspected to determine
86    if a guessing attack is in progress against a particular user
87    account.
88
89    @type closeSocket: bool
90    @ivar closeSocket: If the socket should be closed when the
91    connection is closed (writable).
92
    If you set this to True, TLS Lite will assume the responsibility of
    closing the socket when the TLS Connection is shut down (either
    through an error or through the user calling close()).  The default
    is False.
97
98    @type ignoreAbruptClose: bool
    @ivar ignoreAbruptClose: If an abrupt close of the socket should be
    ignored instead of raising an error (writable).
101
102    If you set this to True, TLS Lite will not raise a
103    L{tlslite.errors.TLSAbruptCloseError} exception if the underlying
104    socket is unexpectedly closed.  Such an unexpected closure could be
105    caused by an attacker.  However, it also occurs with some incorrect
106    TLS implementations.
107
108    You should set this to True only if you're not worried about an
109    attacker truncating the connection, and only if necessary to avoid
110    spurious errors.  The default is False.
111
112    @sort: __init__, read, readAsync, write, writeAsync, close, closeAsync,
113    getCipherImplementation, getCipherName
114    """
115
116    def __init__(self, sock):
117        self.sock = sock
118
119        #My session object (Session instance; read-only)
120        self.session = None
121
122        #Am I a client or server?
123        self._client = None
124
125        #Buffers for processing messages
126        self._handshakeBuffer = []
127        self._readBuffer = ""
128
129        #Handshake digests
130        self._handshake_md5 = md5.md5()
131        self._handshake_sha = sha.sha()
132
133        #TLS Protocol Version
134        self.version = (0,0) #read-only
135        self._versionCheck = False #Once we choose a version, this is True
136
137        #Current and Pending connection states
138        self._writeState = _ConnectionState()
139        self._readState = _ConnectionState()
140        self._pendingWriteState = _ConnectionState()
141        self._pendingReadState = _ConnectionState()
142
143        #Is the connection open?
144        self.closed = True #read-only
145        self._refCount = 0 #Used to trigger closure
146
147        #Is this a resumed (or shared-key) session?
148        self.resumed = False #read-only
149
150        #What username did the client claim in his handshake?
151        self.allegedSharedKeyUsername = None
152        self.allegedSrpUsername = None
153
154        #On a call to close(), do we close the socket? (writeable)
155        self.closeSocket = False
156
157        #If the socket is abruptly closed, do we ignore it
158        #and pretend the connection was shut down properly? (writeable)
159        self.ignoreAbruptClose = False
160
161        #Fault we will induce, for testing purposes
162        self.fault = None
163
164    #*********************************************************
165    # Public Functions START
166    #*********************************************************
167
168    def read(self, max=None, min=1):
169        """Read some data from the TLS connection.
170
171        This function will block until at least 'min' bytes are
172        available (or the connection is closed).
173
174        If an exception is raised, the connection will have been
175        automatically closed.
176
177        @type max: int
178        @param max: The maximum number of bytes to return.
179
180        @type min: int
181        @param min: The minimum number of bytes to return
182
183        @rtype: str
184        @return: A string of no more than 'max' bytes, and no fewer
185        than 'min' (unless the connection has been closed, in which
186        case fewer than 'min' bytes may be returned).
187
188        @raise socket.error: If a socket error occurs.
189        @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
190        without a preceding alert.
191        @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
192        """
193        for result in self.readAsync(max, min):
194            pass
195        return result
196
197    def readAsync(self, max=None, min=1):
198        """Start a read operation on the TLS connection.
199
200        This function returns a generator which behaves similarly to
201        read().  Successive invocations of the generator will return 0
202        if it is waiting to read from the socket, 1 if it is waiting
203        to write to the socket, or a string if the read operation has
204        completed.
205
206        @rtype: iterable
207        @return: A generator; see above for details.
208        """
209        try:
210            while len(self._readBuffer)<min and not self.closed:
211                try:
212                    for result in self._getMsg(ContentType.application_data):
213                        if result in (0,1):
214                            yield result
215                    applicationData = result
216                    self._readBuffer += bytesToString(applicationData.write())
217                except TLSRemoteAlert, alert:
218                    if alert.description != AlertDescription.close_notify:
219                        raise
220                except TLSAbruptCloseError:
221                    if not self.ignoreAbruptClose:
222                        raise
223                    else:
224                        self._shutdown(True)
225
226            if max == None:
227                max = len(self._readBuffer)
228
229            returnStr = self._readBuffer[:max]
230            self._readBuffer = self._readBuffer[max:]
231            yield returnStr
232        except:
233            self._shutdown(False)
234            raise
235
236    def write(self, s):
237        """Write some data to the TLS connection.
238
239        This function will block until all the data has been sent.
240
241        If an exception is raised, the connection will have been
242        automatically closed.
243
244        @type s: str
245        @param s: The data to transmit to the other party.
246
247        @raise socket.error: If a socket error occurs.
248        """
249        for result in self.writeAsync(s):
250            pass
251
252    def writeAsync(self, s):
253        """Start a write operation on the TLS connection.
254
255        This function returns a generator which behaves similarly to
256        write().  Successive invocations of the generator will return
257        1 if it is waiting to write to the socket, or will raise
258        StopIteration if the write operation has completed.
259
260        @rtype: iterable
261        @return: A generator; see above for details.
262        """
263        try:
264            if self.closed:
265                raise ValueError()
266
267            index = 0
268            blockSize = 16384
269            skipEmptyFrag = False
270            while 1:
271                startIndex = index * blockSize
272                endIndex = startIndex + blockSize
273                if startIndex >= len(s):
274                    break
275                if endIndex > len(s):
276                    endIndex = len(s)
277                block = stringToBytes(s[startIndex : endIndex])
278                applicationData = ApplicationData().create(block)
279                for result in self._sendMsg(applicationData, skipEmptyFrag):
280                    yield result
281                skipEmptyFrag = True #only send an empy fragment on 1st message
282                index += 1
283        except:
284            self._shutdown(False)
285            raise
286
287    def close(self):
288        """Close the TLS connection.
289
290        This function will block until it has exchanged close_notify
291        alerts with the other party.  After doing so, it will shut down the
292        TLS connection.  Further attempts to read through this connection
293        will return "".  Further attempts to write through this connection
294        will raise ValueError.
295
296        If makefile() has been called on this connection, the connection
297        will be not be closed until the connection object and all file
298        objects have been closed.
299
300        Even if an exception is raised, the connection will have been
301        closed.
302
303        @raise socket.error: If a socket error occurs.
304        @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
305        without a preceding alert.
306        @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
307        """
308        if not self.closed:
309            for result in self._decrefAsync():
310                pass
311
312    def closeAsync(self):
313        """Start a close operation on the TLS connection.
314
315        This function returns a generator which behaves similarly to
316        close().  Successive invocations of the generator will return 0
317        if it is waiting to read from the socket, 1 if it is waiting
318        to write to the socket, or will raise StopIteration if the
319        close operation has completed.
320
321        @rtype: iterable
322        @return: A generator; see above for details.
323        """
324        if not self.closed:
325            for result in self._decrefAsync():
326                yield result
327
328    def _decrefAsync(self):
329        self._refCount -= 1
330        if self._refCount == 0 and not self.closed:
331            try:
332                for result in self._sendMsg(Alert().create(\
333                        AlertDescription.close_notify, AlertLevel.warning)):
334                    yield result
335                alert = None
336                # Forcing a shutdown as WinHTTP does not seem to be
337                # responsive to the close notify.
338                prevCloseSocket = self.closeSocket
339                self.closeSocket = True
340                self._shutdown(True)
341                self.closeSocket = prevCloseSocket
342                while not alert:
343                    for result in self._getMsg((ContentType.alert, \
344                                              ContentType.application_data)):
345                        if result in (0,1):
346                            yield result
347                    if result.contentType == ContentType.alert:
348                        alert = result
349                if alert.description == AlertDescription.close_notify:
350                    self._shutdown(True)
351                else:
352                    raise TLSRemoteAlert(alert)
353            except (socket.error, TLSAbruptCloseError):
354                #If the other side closes the socket, that's okay
355                self._shutdown(True)
356            except:
357                self._shutdown(False)
358                raise
359
360    def getCipherName(self):
361        """Get the name of the cipher used with this connection.
362
363        @rtype: str
364        @return: The name of the cipher used with this connection.
365        Either 'aes128', 'aes256', 'rc4', or '3des'.
366        """
367        if not self._writeState.encContext:
368            return None
369        return self._writeState.encContext.name
370
371    def getCipherImplementation(self):
372        """Get the name of the cipher implementation used with
373        this connection.
374
375        @rtype: str
376        @return: The name of the cipher implementation used with
377        this connection.  Either 'python', 'cryptlib', 'openssl',
378        or 'pycrypto'.
379        """
380        if not self._writeState.encContext:
381            return None
382        return self._writeState.encContext.implementation
383
384
385
    #Partial socket emulation: the methods below let a TLS connection be
    #used in place of a plain socket object.
387    def send(self, s):
388        """Send data to the TLS connection (socket emulation).
389
390        @raise socket.error: If a socket error occurs.
391        """
392        self.write(s)
393        return len(s)
394
395    def sendall(self, s):
396        """Send data to the TLS connection (socket emulation).
397
398        @raise socket.error: If a socket error occurs.
399        """
400        self.write(s)
401
402    def recv(self, bufsize):
403        """Get some data from the TLS connection (socket emulation).
404
405        @raise socket.error: If a socket error occurs.
406        @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
407        without a preceding alert.
408        @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
409        """
410        return self.read(bufsize)
411
412    def makefile(self, mode='r', bufsize=-1):
413        """Create a file object for the TLS connection (socket emulation).
414
415        @rtype: L{tlslite.FileObject.FileObject}
416        """
417        self._refCount += 1
418        return FileObject(self, mode, bufsize)
419
420    def getsockname(self):
421        """Return the socket's own address (socket emulation)."""
422        return self.sock.getsockname()
423
424    def getpeername(self):
425        """Return the remote address to which the socket is connected
426        (socket emulation)."""
427        return self.sock.getpeername()
428
429    def settimeout(self, value):
430        """Set a timeout on blocking socket operations (socket emulation)."""
431        return self.sock.settimeout(value)
432
433    def gettimeout(self):
434        """Return the timeout associated with socket operations (socket
435        emulation)."""
436        return self.sock.gettimeout()
437
438    def setsockopt(self, level, optname, value):
439        """Set the value of the given socket option (socket emulation)."""
440        return self.sock.setsockopt(level, optname, value)
441
442
443     #*********************************************************
444     # Public Functions END
445     #*********************************************************
446
447    def _shutdown(self, resumable):
448        self._writeState = _ConnectionState()
449        self._readState = _ConnectionState()
450        #Don't do this: self._readBuffer = ""
451        self.version = (0,0)
452        self._versionCheck = False
453        self.closed = True
454        if self.closeSocket:
455            self.sock.close()
456
457        #Even if resumable is False, we'll never toggle this on
458        if not resumable and self.session:
459            self.session.resumable = False
460
461
462    def _sendError(self, alertDescription, errorStr=None):
463        alert = Alert().create(alertDescription, AlertLevel.fatal)
464        for result in self._sendMsg(alert):
465            yield result
466        self._shutdown(False)
467        raise TLSLocalAlert(alert, errorStr)
468
469    def _sendMsgs(self, msgs):
470        skipEmptyFrag = False
471        for msg in msgs:
472            for result in self._sendMsg(msg, skipEmptyFrag):
473                yield result
474            skipEmptyFrag = True
475
    def _sendMsg(self, msg, skipEmptyFrag=False):
        """Generator that protects, frames and writes one message.

        The current write state is applied: MAC first, then padding and
        encryption.  A record header is then prepended and the record is
        sent.  Yields 1 whenever the socket is not ready (EWOULDBLOCK) or a
        partial send leaves data outstanding.  Finishes once the whole
        record has been written.

        @param msg: A message object providing write() and contentType.
        @param skipEmptyFrag: If True, do not precede this record with the
        empty application-data fragment otherwise sent on open TLS 1.0
        connections that use a block cipher.
        """
        bytes = msg.write()
        contentType = msg.contentType

        #Whenever we're connected and asked to send a message,
        #we first send an empty Application Data message.  This prevents
        #an attacker from launching a chosen-plaintext attack based on
        #knowing the next IV.
        #(Only for TLS 1.0 with a block cipher; the recursive call passes
        #skipEmptyFrag=True so it does not recurse again.)
        if not self.closed and not skipEmptyFrag and self.version == (3,1):
            if self._writeState.encContext:
                if self._writeState.encContext.isBlockCipher:
                    for result in self._sendMsg(ApplicationData(),
                                               skipEmptyFrag=True):
                        yield result

        #Update handshake hashes
        if contentType == ContentType.handshake:
            bytesStr = bytesToString(bytes)
            self._handshake_md5.update(bytesStr)
            self._handshake_sha.update(bytesStr)

        #Calculate MAC over: sequence number, content type, (TLS only) the
        #two version bytes, the 2-byte length, and the plaintext fragment.
        #getSeqNumStr() also advances the write sequence number.
        if self._writeState.macContext:
            seqnumStr = self._writeState.getSeqNumStr()
            bytesStr = bytesToString(bytes)
            mac = self._writeState.macContext.copy()
            mac.update(seqnumStr)
            mac.update(chr(contentType))
            if self.version == (3,0):
                mac.update( chr( int(len(bytes)/256) ) )
                mac.update( chr( int(len(bytes)%256) ) )
            elif self.version in ((3,1), (3,2)):
                mac.update(chr(self.version[0]))
                mac.update(chr(self.version[1]))
                mac.update( chr( int(len(bytes)/256) ) )
                mac.update( chr( int(len(bytes)%256) ) )
            else:
                raise AssertionError()
            mac.update(bytesStr)
            macString = mac.digest()
            macBytes = stringToBytes(macString)
            #Test hook: deliberately corrupt the MAC
            if self.fault == Fault.badMAC:
                macBytes[0] = (macBytes[0]+1) % 256

        #Encrypt for Block or Stream Cipher
        #NOTE(review): macBytes is only bound when macContext is set; this
        #assumes encContext is never set without a macContext -- confirm.
        if self._writeState.encContext:
            #Add padding and encrypt (for Block Cipher):
            if self._writeState.encContext.isBlockCipher:

                #Add TLS 1.1 fixed block
                if self.version == (3,2):
                    bytes = self.fixedIVBlock + bytes

                #Add padding: bytes = bytes + (macBytes + paddingBytes)
                #paddingLength is in 1..blockLength, and paddingLength+1
                #bytes of value paddingLength are appended.  When the data
                #plus MAC plus 1 is already block-aligned, a whole extra
                #block of padding is added.
                #NOTE(review): SSL 3.0 requires padding shorter than the
                #block size; confirm (3,0) peers accept paddingLength ==
                #blockLength.
                currentLength = len(bytes) + len(macBytes) + 1
                blockLength = self._writeState.encContext.block_size
                paddingLength = blockLength-(currentLength % blockLength)

                paddingBytes = createByteArraySequence([paddingLength] * \
                                                       (paddingLength+1))
                #Test hook: deliberately corrupt the padding
                if self.fault == Fault.badPadding:
                    paddingBytes[0] = (paddingBytes[0]+1) % 256
                endBytes = concatArrays(macBytes, paddingBytes)
                bytes = concatArrays(bytes, endBytes)
                #Encrypt
                #NOTE(review): the stream-cipher branch below hands encrypt()
                #a string (bytesToString), while this branch hands it the
                #result of stringToBytes(bytes) -- confirm every block
                #cipher implementation accepts that.
                plaintext = stringToBytes(bytes)
                ciphertext = self._writeState.encContext.encrypt(plaintext)
                bytes = stringToBytes(ciphertext)

            #Encrypt (for Stream Cipher)
            else:
                bytes = concatArrays(bytes, macBytes)
                plaintext = bytesToString(bytes)
                ciphertext = self._writeState.encContext.encrypt(plaintext)
                bytes = stringToBytes(ciphertext)

        #Add record header and send
        r = RecordHeader3().create(self.version, contentType, len(bytes))
        s = bytesToString(concatArrays(r.write(), bytes))
        #Keep sending until the whole record is out, yielding 1 whenever
        #the socket would block or only part of the data was accepted.
        while 1:
            try:
                bytesSent = self.sock.send(s) #Might raise socket.error
            except socket.error, why:
                if why[0] == errno.EWOULDBLOCK:
                    yield 1
                    continue
                else:
                    raise
            if bytesSent == len(s):
                return
            s = s[bytesSent:]
            yield 1
568
569
    def _getMsg(self, expectedType, secondaryType=None, constructorType=None):
        """Generator that reads and parses the next expected message.

        While waiting on the socket it yields 0 (read) or 1 (write, e.g.
        while sending an alert).  The final value yielded is the parsed
        message object.

        @param expectedType: A ContentType value, or a tuple of them, that
        the caller is prepared to receive.
        @param secondaryType: For handshake records, the HandshakeType (or
        tuple of them) the caller accepts.  Any other subtype triggers a
        fatal unexpected_message alert.
        @param constructorType: Passed through to the Certificate,
        ServerKeyExchange and ClientKeyExchange constructors.

        Records that do not match expectedType are handled as follows:
          - An alert: the connection is shut down and TLSRemoteAlert is
            raised.
          - A renegotiation attempt: answered with a no_renegotiation
            warning, then skipped.
          - Anything else: ends in a fatal alert via _sendError(), which
            raises TLSLocalAlert.
        Parse errors (SyntaxError) also end in a fatal alert via
        _sendError().
        """
        try:
            if not isinstance(expectedType, tuple):
                expectedType = (expectedType,)

            #Spin in a loop, until we've got a non-empty record of a type we
            #expect.  The loop will be repeated if:
            #  - we receive a renegotiation attempt; we send no_renegotiation,
            #    then try again
            #  - we receive an empty application-data fragment; we try again
            while 1:
                for result in self._getNextRecord():
                    if result in (0,1):
                        yield result
                #The last value from _getNextRecord is (header, Parser)
                recordHeader, p = result

                #If this is an empty application-data fragment, try again
                if recordHeader.type == ContentType.application_data:
                    if p.index == len(p.bytes):
                        continue

                #If we received an unexpected record type...
                if recordHeader.type not in expectedType:

                    #If we received an alert...
                    if recordHeader.type == ContentType.alert:
                        alert = Alert().parse(p)

                        #We either received a fatal error, a warning, or a
                        #close_notify.  In any case, we're going to close the
                        #connection.  In the latter two cases we respond with
                        #a close_notify, but ignore any socket errors, since
                        #the other side might have already closed the socket.
                        if alert.level == AlertLevel.warning or \
                           alert.description == AlertDescription.close_notify:

                            #If the sendMsg() call fails because the socket has
                            #already been closed, we will be forgiving and not
                            #report the error nor invalidate the "resumability"
                            #of the session.
                            try:
                                alertMsg = Alert()
                                alertMsg.create(AlertDescription.close_notify,
                                                AlertLevel.warning)
                                for result in self._sendMsg(alertMsg):
                                    yield result
                            except socket.error:
                                pass

                            #close_notify keeps the session resumable; any
                            #other warning alert invalidates it.
                            if alert.description == \
                                   AlertDescription.close_notify:
                                self._shutdown(True)
                            elif alert.level == AlertLevel.warning:
                                self._shutdown(False)

                        else: #Fatal alert:
                            self._shutdown(False)

                        #Raise the alert as an exception
                        raise TLSRemoteAlert(alert)

                    #If we received a renegotiation attempt...
                    #(a client sees hello_request, a server sees client_hello)
                    if recordHeader.type == ContentType.handshake:
                        subType = p.get(1)
                        reneg = False
                        if self._client:
                            if subType == HandshakeType.hello_request:
                                reneg = True
                        else:
                            if subType == HandshakeType.client_hello:
                                reneg = True
                        #Send no_renegotiation, then try again
                        if reneg:
                            alertMsg = Alert()
                            alertMsg.create(AlertDescription.no_renegotiation,
                                            AlertLevel.warning)
                            for result in self._sendMsg(alertMsg):
                                yield result
                            continue

                    #Otherwise: this is an unexpected record, but neither an
                    #alert nor renegotiation
                    #(_sendError always raises, so control never falls
                    #through to the break below from here)
                    for result in self._sendError(\
                            AlertDescription.unexpected_message,
                            "received type=%d" % recordHeader.type):
                        yield result

                break

            #Parse based on content_type
            if recordHeader.type == ContentType.change_cipher_spec:
                yield ChangeCipherSpec().parse(p)
            elif recordHeader.type == ContentType.alert:
                yield Alert().parse(p)
            elif recordHeader.type == ContentType.application_data:
                yield ApplicationData().parse(p)
            elif recordHeader.type == ContentType.handshake:
                #Convert secondaryType to tuple, if it isn't already
                if not isinstance(secondaryType, tuple):
                    secondaryType = (secondaryType,)

                #If it's a handshake message, check handshake header
                #(SSLv2-framed records are only accepted as a ClientHello)
                if recordHeader.ssl2:
                    subType = p.get(1)
                    if subType != HandshakeType.client_hello:
                        for result in self._sendError(\
                                AlertDescription.unexpected_message,
                                "Can only handle SSLv2 ClientHello messages"):
                            yield result
                    if HandshakeType.client_hello not in secondaryType:
                        for result in self._sendError(\
                                AlertDescription.unexpected_message):
                            yield result
                    subType = HandshakeType.client_hello
                else:
                    subType = p.get(1)
                    if subType not in secondaryType:
                        for result in self._sendError(\
                                AlertDescription.unexpected_message,
                                "Expecting %s, got %s" % (str(secondaryType), subType)):
                            yield result

                #Update handshake hashes
                #(over the whole handshake message, p.bytes)
                sToHash = bytesToString(p.bytes)
                self._handshake_md5.update(sToHash)
                self._handshake_sha.update(sToHash)

                #Parse based on handshake type
                if subType == HandshakeType.client_hello:
                    yield ClientHello(recordHeader.ssl2).parse(p)
                elif subType == HandshakeType.server_hello:
                    yield ServerHello().parse(p)
                elif subType == HandshakeType.certificate:
                    yield Certificate(constructorType).parse(p)
                elif subType == HandshakeType.certificate_request:
                    yield CertificateRequest().parse(p)
                elif subType == HandshakeType.certificate_verify:
                    yield CertificateVerify().parse(p)
                elif subType == HandshakeType.server_key_exchange:
                    yield ServerKeyExchange(constructorType).parse(p)
                elif subType == HandshakeType.server_hello_done:
                    yield ServerHelloDone().parse(p)
                elif subType == HandshakeType.client_key_exchange:
                    yield ClientKeyExchange(constructorType, \
                                            self.version).parse(p)
                elif subType == HandshakeType.finished:
                    yield Finished(self.version).parse(p)
                elif subType == HandshakeType.encrypted_extensions:
                    yield EncryptedExtensions().parse(p)
                else:
                    raise AssertionError()

        #If an exception was raised by a Parser or Message instance:
        except SyntaxError, e:
            for result in self._sendError(AlertDescription.decode_error,
                                         formatExceptionTrace(e)):
                yield result
727
728
729    #Returns next record or next handshake message
730    def _getNextRecord(self):
731
732        #If there's a handshake message waiting, return it
733        if self._handshakeBuffer:
734            recordHeader, bytes = self._handshakeBuffer[0]
735            self._handshakeBuffer = self._handshakeBuffer[1:]
736            yield (recordHeader, Parser(bytes))
737            return
738
739        #Otherwise...
740        #Read the next record header
741        bytes = createByteArraySequence([])
742        recordHeaderLength = 1
743        ssl2 = False
744        while 1:
745            try:
746                s = self.sock.recv(recordHeaderLength-len(bytes))
747            except socket.error, why:
748                if why[0] == errno.EWOULDBLOCK:
749                    yield 0
750                    continue
751                else:
752                    raise
753
754            #If the connection was abruptly closed, raise an error
755            if len(s)==0:
756                raise TLSAbruptCloseError()
757
758            bytes += stringToBytes(s)
759            if len(bytes)==1:
760                if bytes[0] in ContentType.all:
761                    ssl2 = False
762                    recordHeaderLength = 5
763                elif bytes[0] == 128:
764                    ssl2 = True
765                    recordHeaderLength = 2
766                else:
767                    raise SyntaxError()
768            if len(bytes) == recordHeaderLength:
769                break
770
771        #Parse the record header
772        if ssl2:
773            r = RecordHeader2().parse(Parser(bytes))
774        else:
775            r = RecordHeader3().parse(Parser(bytes))
776
777        #Check the record header fields
778        if r.length > 18432:
779            for result in self._sendError(AlertDescription.record_overflow):
780                yield result
781
782        #Read the record contents
783        bytes = createByteArraySequence([])
784        while 1:
785            try:
786                s = self.sock.recv(r.length - len(bytes))
787            except socket.error, why:
788                if why[0] == errno.EWOULDBLOCK:
789                    yield 0
790                    continue
791                else:
792                    raise
793
794            #If the connection is closed, raise a socket error
795            if len(s)==0:
796                    raise TLSAbruptCloseError()
797
798            bytes += stringToBytes(s)
799            if len(bytes) == r.length:
800                break
801
802        #Check the record header fields (2)
803        #We do this after reading the contents from the socket, so that
804        #if there's an error, we at least don't leave extra bytes in the
805        #socket..
806        #
807        # THIS CHECK HAS NO SECURITY RELEVANCE (?), BUT COULD HURT INTEROP.
808        # SO WE LEAVE IT OUT FOR NOW.
809        #
810        #if self._versionCheck and r.version != self.version:
811        #    for result in self._sendError(AlertDescription.protocol_version,
812        #            "Version in header field: %s, should be %s" % (str(r.version),
813        #                                                       str(self.version))):
814        #        yield result
815
816        #Decrypt the record
817        for result in self._decryptRecord(r.type, bytes):
818            if result in (0,1):
819                yield result
820            else:
821                break
822        bytes = result
823        p = Parser(bytes)
824
825        #If it doesn't contain handshake messages, we can just return it
826        if r.type != ContentType.handshake:
827            yield (r, p)
828        #If it's an SSLv2 ClientHello, we can return it as well
829        elif r.ssl2:
830            yield (r, p)
831        else:
832            #Otherwise, we loop through and add the handshake messages to the
833            #handshake buffer
834            while 1:
835                if p.index == len(bytes): #If we're at the end
836                    if not self._handshakeBuffer:
837                        for result in self._sendError(\
838                                AlertDescription.decode_error, \
839                                "Received empty handshake record"):
840                            yield result
841                    break
842                #There needs to be at least 4 bytes to get a header
843                if p.index+4 > len(bytes):
844                    for result in self._sendError(\
845                            AlertDescription.decode_error,
846                            "A record has a partial handshake message (1)"):
847                        yield result
848                p.get(1) # skip handshake type
849                msgLength = p.get(3)
850                if p.index+msgLength > len(bytes):
851                    for result in self._sendError(\
852                            AlertDescription.decode_error,
853                            "A record has a partial handshake message (2)"):
854                        yield result
855
856                handshakePair = (r, bytes[p.index-4 : p.index+msgLength])
857                self._handshakeBuffer.append(handshakePair)
858                p.index += msgLength
859
860            #We've moved at least one handshake message into the
861            #handshakeBuffer, return the first one
862            recordHeader, bytes = self._handshakeBuffer[0]
863            self._handshakeBuffer = self._handshakeBuffer[1:]
864            yield (recordHeader, Parser(bytes))
865
866
867    def _decryptRecord(self, recordType, bytes):
868        if self._readState.encContext:
869
870            #Decrypt if it's a block cipher
871            if self._readState.encContext.isBlockCipher:
872                blockLength = self._readState.encContext.block_size
873                if len(bytes) % blockLength != 0:
874                    for result in self._sendError(\
875                            AlertDescription.decryption_failed,
876                            "Encrypted data not a multiple of blocksize"):
877                        yield result
878                ciphertext = bytesToString(bytes)
879                plaintext = self._readState.encContext.decrypt(ciphertext)
880                if self.version == (3,2): #For TLS 1.1, remove explicit IV
881                    plaintext = plaintext[self._readState.encContext.block_size : ]
882                bytes = stringToBytes(plaintext)
883
884                #Check padding
885                paddingGood = True
886                paddingLength = bytes[-1]
887                if (paddingLength+1) > len(bytes):
888                    paddingGood=False
889                    totalPaddingLength = 0
890                else:
891                    if self.version == (3,0):
892                        totalPaddingLength = paddingLength+1
893                    elif self.version in ((3,1), (3,2)):
894                        totalPaddingLength = paddingLength+1
895                        paddingBytes = bytes[-totalPaddingLength:-1]
896                        for byte in paddingBytes:
897                            if byte != paddingLength:
898                                paddingGood = False
899                                totalPaddingLength = 0
900                    else:
901                        raise AssertionError()
902
903            #Decrypt if it's a stream cipher
904            else:
905                paddingGood = True
906                ciphertext = bytesToString(bytes)
907                plaintext = self._readState.encContext.decrypt(ciphertext)
908                bytes = stringToBytes(plaintext)
909                totalPaddingLength = 0
910
911            #Check MAC
912            macGood = True
913            macLength = self._readState.macContext.digest_size
914            endLength = macLength + totalPaddingLength
915            if endLength > len(bytes):
916                macGood = False
917            else:
918                #Read MAC
919                startIndex = len(bytes) - endLength
920                endIndex = startIndex + macLength
921                checkBytes = bytes[startIndex : endIndex]
922
923                #Calculate MAC
924                seqnumStr = self._readState.getSeqNumStr()
925                bytes = bytes[:-endLength]
926                bytesStr = bytesToString(bytes)
927                mac = self._readState.macContext.copy()
928                mac.update(seqnumStr)
929                mac.update(chr(recordType))
930                if self.version == (3,0):
931                    mac.update( chr( int(len(bytes)/256) ) )
932                    mac.update( chr( int(len(bytes)%256) ) )
933                elif self.version in ((3,1), (3,2)):
934                    mac.update(chr(self.version[0]))
935                    mac.update(chr(self.version[1]))
936                    mac.update( chr( int(len(bytes)/256) ) )
937                    mac.update( chr( int(len(bytes)%256) ) )
938                else:
939                    raise AssertionError()
940                mac.update(bytesStr)
941                macString = mac.digest()
942                macBytes = stringToBytes(macString)
943
944                #Compare MACs
945                if macBytes != checkBytes:
946                    macGood = False
947
948            if not (paddingGood and macGood):
949                for result in self._sendError(AlertDescription.bad_record_mac,
950                                          "MAC failure (or padding failure)"):
951                    yield result
952
953        yield bytes
954
955    def _handshakeStart(self, client):
956        self._client = client
957        self._handshake_md5 = md5.md5()
958        self._handshake_sha = sha.sha()
959        self._handshakeBuffer = []
960        self.allegedSharedKeyUsername = None
961        self.allegedSrpUsername = None
962        self._refCount = 1
963
964    def _handshakeDone(self, resumed):
965        self.resumed = resumed
966        self.closed = False
967
968    def _calcPendingStates(self, clientRandom, serverRandom, implementations):
969        if self.session.cipherSuite in CipherSuite.aes128Suites:
970            macLength = 20
971            keyLength = 16
972            ivLength = 16
973            createCipherFunc = createAES
974        elif self.session.cipherSuite in CipherSuite.aes256Suites:
975            macLength = 20
976            keyLength = 32
977            ivLength = 16
978            createCipherFunc = createAES
979        elif self.session.cipherSuite in CipherSuite.rc4Suites:
980            macLength = 20
981            keyLength = 16
982            ivLength = 0
983            createCipherFunc = createRC4
984        elif self.session.cipherSuite in CipherSuite.tripleDESSuites:
985            macLength = 20
986            keyLength = 24
987            ivLength = 8
988            createCipherFunc = createTripleDES
989        else:
990            raise AssertionError()
991
992        if self.version == (3,0):
993            createMACFunc = MAC_SSL
994        elif self.version in ((3,1), (3,2)):
995            createMACFunc = hmac.HMAC
996
997        outputLength = (macLength*2) + (keyLength*2) + (ivLength*2)
998
999        #Calculate Keying Material from Master Secret
1000        if self.version == (3,0):
1001            keyBlock = PRF_SSL(self.session.masterSecret,
1002                               concatArrays(serverRandom, clientRandom),
1003                               outputLength)
1004        elif self.version in ((3,1), (3,2)):
1005            keyBlock = PRF(self.session.masterSecret,
1006                           "key expansion",
1007                           concatArrays(serverRandom,clientRandom),
1008                           outputLength)
1009        else:
1010            raise AssertionError()
1011
1012        #Slice up Keying Material
1013        clientPendingState = _ConnectionState()
1014        serverPendingState = _ConnectionState()
1015        p = Parser(keyBlock)
1016        clientMACBlock = bytesToString(p.getFixBytes(macLength))
1017        serverMACBlock = bytesToString(p.getFixBytes(macLength))
1018        clientKeyBlock = bytesToString(p.getFixBytes(keyLength))
1019        serverKeyBlock = bytesToString(p.getFixBytes(keyLength))
1020        clientIVBlock  = bytesToString(p.getFixBytes(ivLength))
1021        serverIVBlock  = bytesToString(p.getFixBytes(ivLength))
1022        clientPendingState.macContext = createMACFunc(clientMACBlock,
1023                                                      digestmod=sha)
1024        serverPendingState.macContext = createMACFunc(serverMACBlock,
1025                                                      digestmod=sha)
1026        clientPendingState.encContext = createCipherFunc(clientKeyBlock,
1027                                                         clientIVBlock,
1028                                                         implementations)
1029        serverPendingState.encContext = createCipherFunc(serverKeyBlock,
1030                                                         serverIVBlock,
1031                                                         implementations)
1032
1033        #Assign new connection states to pending states
1034        if self._client:
1035            self._pendingWriteState = clientPendingState
1036            self._pendingReadState = serverPendingState
1037        else:
1038            self._pendingWriteState = serverPendingState
1039            self._pendingReadState = clientPendingState
1040
1041        if self.version == (3,2) and ivLength:
1042            #Choose fixedIVBlock for TLS 1.1 (this is encrypted with the CBC
1043            #residue to create the IV for each sent block)
1044            self.fixedIVBlock = getRandomBytes(ivLength)
1045
1046    def _changeWriteState(self):
1047        self._writeState = self._pendingWriteState
1048        self._pendingWriteState = _ConnectionState()
1049
1050    def _changeReadState(self):
1051        self._readState = self._pendingReadState
1052        self._pendingReadState = _ConnectionState()
1053
1054    def _sendFinished(self):
1055        #Send ChangeCipherSpec
1056        for result in self._sendMsg(ChangeCipherSpec()):
1057            yield result
1058
1059        #Switch to pending write state
1060        self._changeWriteState()
1061
1062        #Calculate verification data
1063        verifyData = self._calcFinished(True)
1064        if self.fault == Fault.badFinished:
1065            verifyData[0] = (verifyData[0]+1)%256
1066
1067        #Send Finished message under new state
1068        finished = Finished(self.version).create(verifyData)
1069        for result in self._sendMsg(finished):
1070            yield result
1071
1072    def _getChangeCipherSpec(self):
1073        #Get and check ChangeCipherSpec
1074        for result in self._getMsg(ContentType.change_cipher_spec):
1075            if result in (0,1):
1076                yield result
1077        changeCipherSpec = result
1078
1079        if changeCipherSpec.type != 1:
1080            for result in self._sendError(AlertDescription.illegal_parameter,
1081                                         "ChangeCipherSpec type incorrect"):
1082                yield result
1083
1084        #Switch to pending read state
1085        self._changeReadState()
1086
1087    def _getEncryptedExtensions(self):
1088        for result in self._getMsg(ContentType.handshake,
1089                                   HandshakeType.encrypted_extensions):
1090            if result in (0,1):
1091                yield result
1092        encrypted_extensions = result
1093        self.channel_id = encrypted_extensions.channel_id_key
1094
1095    def _getFinished(self):
1096        #Calculate verification data
1097        verifyData = self._calcFinished(False)
1098
1099        #Get and check Finished message under new state
1100        for result in self._getMsg(ContentType.handshake,
1101                                  HandshakeType.finished):
1102            if result in (0,1):
1103                yield result
1104        finished = result
1105        if finished.verify_data != verifyData:
1106            for result in self._sendError(AlertDescription.decrypt_error,
1107                                         "Finished message is incorrect"):
1108                yield result
1109
1110    def _calcFinished(self, send=True):
1111        if self.version == (3,0):
1112            if (self._client and send) or (not self._client and not send):
1113                senderStr = "\x43\x4C\x4E\x54"
1114            else:
1115                senderStr = "\x53\x52\x56\x52"
1116
1117            verifyData = self._calcSSLHandshakeHash(self.session.masterSecret,
1118                                                   senderStr)
1119            return verifyData
1120
1121        elif self.version in ((3,1), (3,2)):
1122            if (self._client and send) or (not self._client and not send):
1123                label = "client finished"
1124            else:
1125                label = "server finished"
1126
1127            handshakeHashes = stringToBytes(self._handshake_md5.digest() + \
1128                                            self._handshake_sha.digest())
1129            verifyData = PRF(self.session.masterSecret, label, handshakeHashes,
1130                             12)
1131            return verifyData
1132        else:
1133            raise AssertionError()
1134
1135    #Used for Finished messages and CertificateVerify messages in SSL v3
1136    def _calcSSLHandshakeHash(self, masterSecret, label):
1137        masterSecretStr = bytesToString(masterSecret)
1138
1139        imac_md5 = self._handshake_md5.copy()
1140        imac_sha = self._handshake_sha.copy()
1141
1142        imac_md5.update(label + masterSecretStr + '\x36'*48)
1143        imac_sha.update(label + masterSecretStr + '\x36'*40)
1144
1145        md5Str = md5.md5(masterSecretStr + ('\x5c'*48) + \
1146                         imac_md5.digest()).digest()
1147        shaStr = sha.sha(masterSecretStr + ('\x5c'*40) + \
1148                         imac_sha.digest()).digest()
1149
1150        return stringToBytes(md5Str + shaStr)
1151