1"""Classes representing TLS messages."""
2
3from utils.compat import *
4from utils.cryptomath import *
5from errors import *
6from utils.codec import *
7from constants import *
8from X509 import X509
9from X509CertChain import X509CertChain
10
11# The sha module is deprecated in Python 2.6
12try:
13    import sha
14except ImportError:
15    from hashlib import sha1 as sha
16
17# The md5 module is deprecated in Python 2.6
18try:
19    import md5
20except ImportError:
21    from hashlib import md5
22
class RecordHeader3:
    """Header of an SSLv3/TLS record.

    Wire layout (5 bytes): content type (1 byte), major version (1 byte),
    minor version (1 byte), record body length (2 bytes).
    """

    def __init__(self):
        self.type = 0
        self.version = (0,0)
        self.length = 0
        self.ssl2 = False

    def create(self, version, type, length):
        """Set all header fields and return self so calls can be chained."""
        self.version = version
        self.length = length
        self.type = type
        return self

    def write(self):
        """Encode the header as a 5-byte sequence."""
        fields = ((self.type, 1),
                  (self.version[0], 1),
                  (self.version[1], 1),
                  (self.length, 2))
        out = Writer(5)
        for value, width in fields:
            out.add(value, width)
        return out.bytes

    def parse(self, p):
        """Read the header fields from parser *p*; returns self."""
        self.type = p.get(1)
        major = p.get(1)
        minor = p.get(1)
        self.version = (major, minor)
        self.length = p.get(2)
        self.ssl2 = False
        return self
50
class RecordHeader2:
    """Header of an SSLv2-format record (used for SSLv2-compatible hellos)."""

    def __init__(self):
        self.type = 0
        self.version = (0,0)
        self.length = 0
        self.ssl2 = True

    def parse(self, p):
        """Read a 2-byte SSLv2 header from parser *p*; returns self.

        Raises SyntaxError unless the first byte is exactly 0x80 (128).
        The record is then treated as a handshake with version (2,0).
        """
        firstByte = p.get(1)
        if firstByte != 128:
            raise SyntaxError()
        self.type = ContentType.handshake
        self.version = (2,0)
        # Only a first byte of exactly 0x80 is accepted, so the length comes
        # from the second byte alone; lengths that need the high bits of the
        # first byte (2-byte-length headers) are not supported.
        self.length = p.get(1)
        return self
66
67
class Msg:
    """Base class for messages that serialize in two passes.

    A subclass's write(trial) is first run with trial=True to measure the
    encoded length (postWrite returns the writer's index), then with
    trial=False into a Writer pre-sized to that length (postWrite returns
    the writer's bytes).
    """

    def preWrite(self, trial):
        """Return an unsized Writer for a trial pass, else a pre-sized one."""
        if not trial:
            return Writer(self.write(True))
        return Writer()

    def postWrite(self, w, trial):
        """Return the measured length (trial pass) or the encoded bytes."""
        if not trial:
            return w.bytes
        return w.index
82
class Alert(Msg):
    """An alert message: a 1-byte level followed by a 1-byte description."""

    def __init__(self):
        self.contentType = ContentType.alert
        self.level = 0
        self.description = 0

    def create(self, description, level=AlertLevel.fatal):
        """Set the description and level (default: fatal); returns self."""
        self.description = description
        self.level = level
        return self

    def parse(self, p):
        """Read exactly two bytes (level, description) from *p*; returns self."""
        p.setLengthCheck(2)
        self.level = p.get(1)
        self.description = p.get(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Encode the alert as two bytes: level then description."""
        out = Writer(2)
        for field in (self.level, self.description):
            out.add(field, 1)
        return out.bytes
106
107
class HandshakeMsg(Msg):
    """Base class for handshake messages, which carry a 4-byte header."""

    def preWrite(self, handshakeType, trial):
        """Return a Writer already holding the handshake header.

        The header is the 1-byte handshake type plus a 3-byte length.  On a
        trial pass the length is a zero placeholder; otherwise the total
        size is measured via write(True) and the body length (total minus
        the 4 header bytes) is written.
        """
        if trial:
            out = Writer()
            bodyLength = 0
        else:
            totalLength = self.write(True)
            out = Writer(totalLength)
            bodyLength = totalLength - 4
        out.add(handshakeType, 1)
        out.add(bodyLength, 3)
        return out
120
121
class ClientHello(HandshakeMsg):
    """A ClientHello handshake message.

    Parses both the SSLv2-compatible form (construct with ssl2=True) and the
    SSLv3/TLS form.  In the SSLv3/TLS form, the SRP (type 6), certificate
    type (type 7), Channel ID, signed certificate timestamp and
    status_request extensions are recognised; any other extension is
    skipped.  write() only emits the certificate type and SRP extensions.
    """

    def __init__(self, ssl2=False):
        self.contentType = ContentType.handshake
        self.ssl2 = ssl2
        self.client_version = (0,0)
        self.random = createByteArrayZeros(32)
        self.session_id = createByteArraySequence([])
        self.cipher_suites = []         # a list of 16-bit values
        self.certificate_types = [CertificateType.x509]
        self.compression_methods = []   # a list of 8-bit values
        self.srp_username = None        # a string
        self.channel_id = False
        self.support_signed_cert_timestamps = False
        self.status_request = False

    def create(self, version, random, session_id, cipher_suites,
               certificate_types=None, srp_username=None):
        """Populate the hello for sending; returns self.

        compression_methods is always set to [0].
        """
        self.client_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suites = cipher_suites
        self.certificate_types = certificate_types
        self.compression_methods = [0]
        self.srp_username = srp_username
        return self

    def parse(self, p):
        """Parse the message body from parser *p*; returns self.

        Raises SyntaxError on malformed input, including an extension whose
        body does not occupy exactly the number of bytes it declares.
        """
        if self.ssl2:
            self.client_version = (p.get(1), p.get(1))
            cipherSpecsLength = p.get(2)
            sessionIDLength = p.get(2)
            randomLength = p.get(2)
            # SSLv2 cipher specs are 3 bytes each.
            self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3))
            self.session_id = p.getFixBytes(sessionIDLength)
            self.random = p.getFixBytes(randomLength)
            # Left-pad a short challenge with zero bytes up to 32 bytes.
            if len(self.random) < 32:
                zeroBytes = 32-len(self.random)
                self.random = createByteArrayZeros(zeroBytes) + self.random
            self.compression_methods = [0]#Fake this value

            # No length check is performed for SSLv2 hellos.
        else:
            p.startLengthCheck(3)
            self.client_version = (p.get(1), p.get(1))
            self.random = p.getFixBytes(32)
            self.session_id = p.getVarBytes(1)
            self.cipher_suites = p.getVarList(2, 2)
            self.compression_methods = p.getVarList(1, 1)
            if not p.atLengthCheck():
                self._parseExtensions(p)
            p.stopLengthCheck()
        return self

    def _parseExtensions(self, p):
        """Parse the 2-byte-length-prefixed extensions list from *p*."""
        totalExtLength = p.get(2)
        soFar = 0
        while soFar < totalExtLength:
            extType = p.get(2)
            extLength = p.get(2)
            extStart = p.index
            if extType == 6:
                self.srp_username = bytesToString(p.getVarBytes(1))
            elif extType == 7:
                self.certificate_types = p.getVarList(1, 1)
            elif extType == ExtensionType.channel_id:
                # The channel_id extension carries no data in a ClientHello.
                if extLength:
                    raise SyntaxError()
                self.channel_id = True
            elif extType == ExtensionType.signed_cert_timestamps:
                if extLength:
                    raise SyntaxError()
                self.support_signed_cert_timestamps = True
            elif extType == ExtensionType.status_request:
                # Extension contents are currently ignored.
                # According to RFC 6066, this is not strictly forbidden
                # (although it is suboptimal):
                # Servers that receive a client hello containing the
                # "status_request" extension MAY return a suitable
                # certificate status response to the client along with
                # their certificate.  If OCSP is requested, they
                # SHOULD use the information contained in the extension
                # when selecting an OCSP responder and SHOULD include
                # request_extensions in the OCSP request.
                p.getFixBytes(extLength)
                self.status_request = True
            else:
                p.getFixBytes(extLength)
            # Every branch must consume exactly extLength bytes; otherwise the
            # next extension header would be read from the wrong offset.
            if p.index - extStart != extLength:
                raise SyntaxError()
            soFar += 4 + extLength
        # An extension running past the declared total length is malformed.
        if soFar != totalExtLength:
            raise SyntaxError()

    def write(self, trial=False):
        """Encode the hello (see Msg for the trial/real two-pass protocol).

        The certificate type extension (type 7) is sent only when
        certificate_types is set and differs from [x509]; the SRP extension
        (type 6) only when srp_username is set.
        """
        w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
        w.add(self.client_version[0], 1)
        w.add(self.client_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.addVarSeq(self.cipher_suites, 2, 2)
        w.addVarSeq(self.compression_methods, 1, 1)

        # Each extension costs 2 (type) + 2 (length) + 1 (inner length byte)
        # + payload bytes.
        extLength = 0
        if self.certificate_types and self.certificate_types != \
                [CertificateType.x509]:
            extLength += 5 + len(self.certificate_types)
        if self.srp_username:
            extLength += 5 + len(self.srp_username)
        if extLength > 0:
            w.add(extLength, 2)

        if self.certificate_types and self.certificate_types != \
                [CertificateType.x509]:
            w.add(7, 2)
            w.add(len(self.certificate_types)+1, 2)
            w.addVarSeq(self.certificate_types, 1, 1)
        if self.srp_username:
            w.add(6, 2)
            w.add(len(self.srp_username)+1, 2)
            w.addVarSeq(stringToBytes(self.srp_username), 1, 1)

        return HandshakeMsg.postWrite(self, w, trial)
234
235
class ServerHello(HandshakeMsg):
    """A ServerHello handshake message.

    parse() recognises only the certificate type extension (type 7) and
    skips all others.  write() can emit the certificate type, channel_id,
    signed_cert_timestamps and status_request extensions.
    """

    def __init__(self):
        self.contentType = ContentType.handshake
        self.server_version = (0,0)
        self.random = createByteArrayZeros(32)
        self.session_id = createByteArraySequence([])
        self.cipher_suite = 0
        self.certificate_type = CertificateType.x509
        self.compression_method = 0
        self.channel_id = False
        self.signed_cert_timestamps = None
        self.status_request = False

    def create(self, version, random, session_id, cipher_suite,
               certificate_type):
        """Populate the hello for sending (compression method 0); returns self."""
        self.server_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suite = cipher_suite
        self.certificate_type = certificate_type
        self.compression_method = 0
        return self

    def parse(self, p):
        """Parse the message body from parser *p*; returns self.

        Raises SyntaxError on malformed input, including a certificate type
        extension whose body is not exactly one byte, or extensions that
        run past their declared total length.
        """
        p.startLengthCheck(3)
        self.server_version = (p.get(1), p.get(1))
        self.random = p.getFixBytes(32)
        self.session_id = p.getVarBytes(1)
        self.cipher_suite = p.get(2)
        self.compression_method = p.get(1)
        if not p.atLengthCheck():
            totalExtLength = p.get(2)
            soFar = 0
            while soFar < totalExtLength:
                extType = p.get(2)
                extLength = p.get(2)
                if extType == 7:
                    # Exactly one byte is read here; any other declared length
                    # would leave the parser misaligned for the next extension.
                    if extLength != 1:
                        raise SyntaxError()
                    self.certificate_type = p.get(1)
                else:
                    p.getFixBytes(extLength)
                soFar += 4 + extLength
            if soFar != totalExtLength:
                raise SyntaxError()
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode the hello (see Msg for the trial/real two-pass protocol)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
        w.add(self.server_version[0], 1)
        w.add(self.server_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.add(self.cipher_suite, 2)
        w.add(self.compression_method, 1)

        # Total extension bytes: 4 bytes of type+length per extension plus
        # its payload.
        extLength = 0
        if self.certificate_type and self.certificate_type != \
                CertificateType.x509:
            extLength += 5

        if self.channel_id:
            extLength += 4

        if self.signed_cert_timestamps:
            extLength += 4 + len(self.signed_cert_timestamps)

        if self.status_request:
            extLength += 4

        if extLength != 0:
            w.add(extLength, 2)

        if self.certificate_type and self.certificate_type != \
                CertificateType.x509:
            w.add(7, 2)
            w.add(1, 2)
            w.add(self.certificate_type, 1)

        if self.channel_id:
            w.add(ExtensionType.channel_id, 2)
            w.add(0, 2)

        if self.signed_cert_timestamps:
            w.add(ExtensionType.signed_cert_timestamps, 2)
            w.addVarSeq(stringToBytes(self.signed_cert_timestamps), 1, 2)

        if self.status_request:
            w.add(ExtensionType.status_request, 2)
            w.add(0, 2)

        return HandshakeMsg.postWrite(self, w, trial)
325
class Certificate(HandshakeMsg):
    """A Certificate handshake message carrying an X.509 or cryptoID chain."""

    def __init__(self, certificateType):
        self.certificateType = certificateType
        self.contentType = ContentType.handshake
        self.certChain = None

    def create(self, certChain):
        """Set the certificate chain to send; returns self."""
        self.certChain = certChain
        return self

    def parse(self, p):
        """Parse a certificate chain from *p*; returns self.

        For X.509, certChain stays None when the list is empty.  Raises
        SyntaxError if a cryptoID chain is received without cryptoIDlib,
        and AssertionError for an unknown certificateType.
        """
        p.startLengthCheck(3)
        if self.certificateType == CertificateType.x509:
            chainLength = p.get(3)
            index = 0
            certificate_list = []
            while index != chainLength:
                certBytes = p.getVarBytes(3)
                x509 = X509()
                x509.parseBinary(certBytes)
                certificate_list.append(x509)
                index += len(certBytes)+3
            if certificate_list:
                self.certChain = X509CertChain(certificate_list)
        elif self.certificateType == CertificateType.cryptoID:
            s = bytesToString(p.getVarBytes(2))
            if s:
                try:
                    import cryptoIDlib.CertChain
                except ImportError:
                    raise SyntaxError(\
                    "cryptoID cert chain received, cryptoIDlib not present")
                self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
        else:
            raise AssertionError()

        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode the chain (see Msg for the trial/real two-pass protocol)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
        if self.certificateType == CertificateType.x509:
            if self.certChain:
                certificate_list = self.certChain.x509List
            else:
                certificate_list = []
            # Serialize each certificate once: the encodings are needed both
            # for the 3-byte chain length prefix and for the entries.
            encodedCerts = [cert.writeBytes() for cert in certificate_list]
            chainLength = 0
            for certBytes in encodedCerts:
                chainLength += len(certBytes)+3
            w.add(chainLength, 3)
            for certBytes in encodedCerts:
                w.addVarSeq(certBytes, 1, 3)
        elif self.certificateType == CertificateType.cryptoID:
            if self.certChain:
                chainBytes = stringToBytes(self.certChain.write())
            else:
                chainBytes = createByteArraySequence([])
            w.addVarSeq(chainBytes, 1, 2)
        else:
            raise AssertionError()
        return HandshakeMsg.postWrite(self, w, trial)
391
class CertificateStatus(HandshakeMsg):
    """A CertificateStatus handshake message carrying an OCSP response."""

    def __init__(self):
        self.contentType = ContentType.handshake

    def create(self, ocsp_response):
        """Set the OCSP response to send; returns self.

        write() passes ocsp_response through stringToBytes(), so it is
        expected to be a string.
        """
        self.ocsp_response = ocsp_response
        return self

    # Defined for the sake of completeness, even though we currently only
    # support sending the status message (server-side), not requesting
    # or receiving it (client-side).
    def parse(self, p):
        """Parse an OCSP CertificateStatus body from *p*; returns self.

        Raises SyntaxError if the status type is not OCSP, the response is
        empty, or the length check fails.
        """
        p.startLengthCheck(3)
        status_type = p.get(1)
        # Only one type is specified, so hardwire it.
        if status_type != CertificateStatusType.ocsp:
            raise SyntaxError()
        ocsp_response = p.getVarBytes(3)
        if not ocsp_response:
            # Can't be empty
            raise SyntaxError()
        # Store the same string form that create() takes and write() encodes,
        # so a parsed message can be written back out.
        self.ocsp_response = bytesToString(ocsp_response)
        # Close the length check opened above, as every other handshake
        # parser does, so trailing bytes are not silently accepted.
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode the message (see Msg for the trial/real two-pass protocol)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.certificate_status,
                                  trial)
        w.add(CertificateStatusType.ocsp, 1)
        w.addVarSeq(stringToBytes(self.ocsp_response), 1, 3)
        return HandshakeMsg.postWrite(self, w, trial)
422
class CertificateRequest(HandshakeMsg):
    """A CertificateRequest: acceptable client cert types and CA names."""

    def __init__(self):
        self.contentType = ContentType.handshake
        #Apple's Secure Transport library rejects empty certificate_types, so
        #default to rsa_sign.
        self.certificate_types = [ClientCertificateType.rsa_sign]
        self.certificate_authorities = []

    def create(self, certificate_types, certificate_authorities):
        """Set the requested cert types and CA names; returns self."""
        self.certificate_types = certificate_types
        self.certificate_authorities = certificate_authorities
        return self

    def parse(self, p):
        """Parse the message body from *p*; returns self."""
        p.startLengthCheck(3)
        self.certificate_types = p.getVarList(1, 1)
        remaining = p.get(2)
        authorities = []
        self.certificate_authorities = authorities
        # Each entry is a 2-byte length prefix plus that many bytes.
        while remaining != 0:
            entry = p.getVarBytes(2)
            authorities.append(entry)
            remaining -= len(entry) + 2
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode the message (see Msg for the trial/real two-pass protocol)."""
        out = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
                                    trial)
        out.addVarSeq(self.certificate_types, 1, 1)
        dnSizes = [len(dn) + 2 for dn in self.certificate_authorities]
        out.add(sum(dnSizes), 2)
        for dn in self.certificate_authorities:
            out.addVarSeq(dn, 1, 2)
        return HandshakeMsg.postWrite(self, out, trial)
462
class ServerKeyExchange(HandshakeMsg):
    """An SRP ServerKeyExchange: N, g, salt, B and an optional signature.

    The signature field is present only for suites in
    CipherSuite.srpRsaSuites.
    """

    def __init__(self, cipherSuite):
        self.cipherSuite = cipherSuite
        self.contentType = ContentType.handshake
        self.srp_N = 0
        self.srp_g = 0
        self.srp_s = createByteArraySequence([])
        self.srp_B = 0
        self.signature = createByteArraySequence([])

    def createSRP(self, srp_N, srp_g, srp_s, srp_B):
        """Set the SRP parameters to send; returns self."""
        self.srp_N = srp_N
        self.srp_g = srp_g
        self.srp_s = srp_s
        self.srp_B = srp_B
        return self

    def parse(self, p):
        """Parse the message body from *p*; returns self."""
        p.startLengthCheck(3)
        self.srp_N = bytesToNumber(p.getVarBytes(2))
        self.srp_g = bytesToNumber(p.getVarBytes(2))
        self.srp_s = p.getVarBytes(1)
        self.srp_B = bytesToNumber(p.getVarBytes(2))
        if self.cipherSuite in CipherSuite.srpRsaSuites:
            self.signature = p.getVarBytes(2)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode the message (see Msg for the trial/real two-pass protocol)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
                                  trial)
        w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
        w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
        w.addVarSeq(self.srp_s, 1, 1)
        w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
        if self.cipherSuite in CipherSuite.srpRsaSuites:
            w.addVarSeq(self.signature, 1, 2)
        return HandshakeMsg.postWrite(self, w, trial)

    def hash(self, clientRandom, serverRandom):
        """Return MD5 || SHA-1 over clientRandom + serverRandom + params.

        The params are this message's body (header stripped) encoded
        without the signature: cipherSuite is temporarily set to None so
        write() omits it, and is restored even if encoding fails.
        """
        # Use hashlib when available; the legacy md5/sha modules are
        # deprecated from Python 2.6.  Binding the constructors here avoids
        # depending on which kind of object the module-level fallback
        # imports bound to the names md5/sha.
        try:
            from hashlib import md5 as md5Hash, sha1 as sha1Hash
        except ImportError:
            from md5 import md5 as md5Hash
            from sha import sha as sha1Hash
        oldCipherSuite = self.cipherSuite
        self.cipherSuite = None
        try:
            data = clientRandom + serverRandom + self.write()[4:]
            s = bytesToString(data)
            return stringToBytes(md5Hash(s).digest() + sha1Hash(s).digest())
        finally:
            self.cipherSuite = oldCipherSuite
511
class ServerHelloDone(HandshakeMsg):
    """ServerHelloDone: a handshake message with no body fields."""

    def __init__(self):
        self.contentType = ContentType.handshake

    def create(self):
        """Nothing to configure; returns self."""
        return self

    def parse(self, p):
        """Read the 3-byte length and end the length check without reading
        any body fields; returns self."""
        p.startLengthCheck(3)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode just the handshake header (see Msg for the two passes)."""
        writer = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done,
                                       trial)
        return HandshakeMsg.postWrite(self, writer, trial)
527
class ClientKeyExchange(HandshakeMsg):
    """A ClientKeyExchange carrying either SRP's A or an RSA-encrypted
    premaster secret, depending on the cipher suite and protocol version.

    Unsupported suites or versions raise AssertionError in parse/write.
    """

    def __init__(self, cipherSuite, version=None):
        self.cipherSuite = cipherSuite
        self.version = version
        self.contentType = ContentType.handshake
        self.srp_A = 0
        self.encryptedPreMasterSecret = createByteArraySequence([])

    def createSRP(self, srp_A):
        """Set the SRP public value A; returns self."""
        self.srp_A = srp_A
        return self

    def createRSA(self, encryptedPreMasterSecret):
        """Set the RSA-encrypted premaster secret; returns self."""
        self.encryptedPreMasterSecret = encryptedPreMasterSecret
        return self

    def _usesSrp(self):
        """True when the cipher suite is an SRP or SRP+RSA suite."""
        srpFamily = CipherSuite.srpSuites + CipherSuite.srpRsaSuites
        return self.cipherSuite in srpFamily

    def parse(self, p):
        """Parse the message body from *p*; returns self."""
        p.startLengthCheck(3)
        if self._usesSrp():
            self.srp_A = bytesToNumber(p.getVarBytes(2))
        elif self.cipherSuite in CipherSuite.rsaSuites:
            if self.version == (3,0):
                # SSLv3: the ciphertext has no length prefix; take every
                # byte remaining in the parser's buffer.
                remaining = len(p.bytes) - p.index
                self.encryptedPreMasterSecret = p.getFixBytes(remaining)
            elif self.version in ((3,1), (3,2)):
                # TLS 1.0/1.1: 2-byte length prefix.
                self.encryptedPreMasterSecret = p.getVarBytes(2)
            else:
                raise AssertionError()
        else:
            raise AssertionError()
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode the message (see Msg for the trial/real two-pass protocol)."""
        out = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
                                    trial)
        if self._usesSrp():
            out.addVarSeq(numberToBytes(self.srp_A), 1, 2)
        elif self.cipherSuite in CipherSuite.rsaSuites:
            if self.version == (3,0):
                out.addFixSeq(self.encryptedPreMasterSecret, 1)
            elif self.version in ((3,1), (3,2)):
                out.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
            else:
                raise AssertionError()
        else:
            raise AssertionError()
        return HandshakeMsg.postWrite(self, out, trial)
578
class CertificateVerify(HandshakeMsg):
    """A CertificateVerify message: a single 2-byte-length-prefixed signature."""

    def __init__(self):
        self.contentType = ContentType.handshake
        self.signature = createByteArraySequence([])

    def create(self, signature):
        """Set the signature to send; returns self."""
        self.signature = signature
        return self

    def parse(self, p):
        """Read the signature from *p*; returns self."""
        p.startLengthCheck(3)
        sig = p.getVarBytes(2)
        self.signature = sig
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode the message (see Msg for the trial/real two-pass protocol)."""
        out = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify,
                                    trial)
        out.addVarSeq(self.signature, 1, 2)
        return HandshakeMsg.postWrite(self, out, trial)
599
class ChangeCipherSpec(Msg):
    """A ChangeCipherSpec message: a single byte whose value is 1 when
    created locally."""

    def __init__(self):
        self.contentType = ContentType.change_cipher_spec
        self.type = 1

    def create(self):
        """Reset type to 1; returns self."""
        self.type = 1
        return self

    def parse(self, p):
        """Read exactly one byte from *p* into type; returns self."""
        p.setLengthCheck(1)
        value = p.get(1)
        self.type = value
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode the single type byte (see Msg for the two passes)."""
        writer = Msg.preWrite(self, trial)
        writer.add(self.type, 1)
        return Msg.postWrite(self, writer, trial)
619
620
class Finished(HandshakeMsg):
    """A Finished message carrying fixed-size verify_data.

    verify_data is 36 bytes for SSLv3 (version (3,0)) and 12 bytes for
    versions (3,1) and (3,2); other versions raise AssertionError in parse.
    """

    def __init__(self, version):
        self.contentType = ContentType.handshake
        self.version = version
        self.verify_data = createByteArraySequence([])

    def create(self, verify_data):
        """Set the verify_data to send; returns self."""
        self.verify_data = verify_data
        return self

    def parse(self, p):
        """Read the version-dependent verify_data from *p*; returns self."""
        p.startLengthCheck(3)
        if self.version == (3,0):
            size = 36
        elif self.version in ((3,1), (3,2)):
            size = 12
        else:
            raise AssertionError()
        self.verify_data = p.getFixBytes(size)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Encode verify_data as-is (see Msg for the two passes)."""
        out = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
        out.addFixSeq(self.verify_data, 1)
        return HandshakeMsg.postWrite(self, out, trial)
646
class EncryptedExtensions(HandshakeMsg):
    """An EncryptedExtensions message; only the Channel ID extension is kept.

    The Channel ID extension body must be exactly 128 bytes; it is split
    into channel_id_key (first 64 bytes) and channel_id_proof (last 64).
    """

    def __init__(self):
        # Set like every other handshake message class in this module.
        self.contentType = ContentType.handshake
        self.channel_id_key = None
        self.channel_id_proof = None

    def parse(self, p):
        """Parse the extension list from *p*; returns self.

        Raises SyntaxError if a Channel ID extension has the wrong length or
        the extensions do not exactly fill the message body.
        """
        p.startLengthCheck(3)
        soFar = 0
        # '<' rather than '!=' so an extension overrunning the body ends the
        # loop and is reported by stopLengthCheck() below.
        while soFar < p.lengthCheck:
            extType = p.get(2)
            extLength = p.get(2)
            if extType == ExtensionType.channel_id:
                if extLength != 32*4:
                    raise SyntaxError()
                self.channel_id_key = p.getFixBytes(64)
                self.channel_id_proof = p.getFixBytes(64)
            else:
                p.getFixBytes(extLength)
            soFar += 4 + extLength
        p.stopLengthCheck()
        return self
668
class ApplicationData(Msg):
    """Opaque application data carried verbatim in a record."""

    def __init__(self):
        self.contentType = ContentType.application_data
        self.bytes = createByteArraySequence([])

    def create(self, bytes):
        """Store the payload to send; returns self."""
        self.bytes = bytes
        return self

    def parse(self, p):
        """Take the parser's entire underlying buffer (not just the part
        after p.index) as the payload; returns self."""
        payload = p.bytes
        self.bytes = payload
        return self

    def write(self):
        """Return the payload unchanged."""
        payload = self.bytes
        return payload
684