1"""Classes representing TLS messages.""" 2 3from utils.compat import * 4from utils.cryptomath import * 5from errors import * 6from utils.codec import * 7from constants import * 8from X509 import X509 9from X509CertChain import X509CertChain 10 11# The sha module is deprecated in Python 2.6 12try: 13 import sha 14except ImportError: 15 from hashlib import sha1 as sha 16 17# The md5 module is deprecated in Python 2.6 18try: 19 import md5 20except ImportError: 21 from hashlib import md5 22 23class RecordHeader3: 24 def __init__(self): 25 self.type = 0 26 self.version = (0,0) 27 self.length = 0 28 self.ssl2 = False 29 30 def create(self, version, type, length): 31 self.type = type 32 self.version = version 33 self.length = length 34 return self 35 36 def write(self): 37 w = Writer(5) 38 w.add(self.type, 1) 39 w.add(self.version[0], 1) 40 w.add(self.version[1], 1) 41 w.add(self.length, 2) 42 return w.bytes 43 44 def parse(self, p): 45 self.type = p.get(1) 46 self.version = (p.get(1), p.get(1)) 47 self.length = p.get(2) 48 self.ssl2 = False 49 return self 50 51class RecordHeader2: 52 def __init__(self): 53 self.type = 0 54 self.version = (0,0) 55 self.length = 0 56 self.ssl2 = True 57 58 def parse(self, p): 59 if p.get(1)!=128: 60 raise SyntaxError() 61 self.type = ContentType.handshake 62 self.version = (2,0) 63 #We don't support 2-byte-length-headers; could be a problem 64 self.length = p.get(1) 65 return self 66 67 68class Msg: 69 def preWrite(self, trial): 70 if trial: 71 w = Writer() 72 else: 73 length = self.write(True) 74 w = Writer(length) 75 return w 76 77 def postWrite(self, w, trial): 78 if trial: 79 return w.index 80 else: 81 return w.bytes 82 83class Alert(Msg): 84 def __init__(self): 85 self.contentType = ContentType.alert 86 self.level = 0 87 self.description = 0 88 89 def create(self, description, level=AlertLevel.fatal): 90 self.level = level 91 self.description = description 92 return self 93 94 def parse(self, p): 95 p.setLengthCheck(2) 96 self.level = p.get(1) 97 self.description = p.get(1) 98 p.stopLengthCheck() 99 return self 100 101 def write(self): 102 w = Writer(2) 103 w.add(self.level, 1) 104 w.add(self.description, 1) 105 return w.bytes 106 107 108class HandshakeMsg(Msg): 109 def preWrite(self, handshakeType, trial): 110 if trial: 111 w = Writer() 112 w.add(handshakeType, 1) 113 w.add(0, 3) 114 else: 115 length = self.write(True) 116 w = Writer(length) 117 w.add(handshakeType, 1) 118 w.add(length-4, 3) 119 return w 120 121 122class ClientHello(HandshakeMsg): 123 def __init__(self, ssl2=False): 124 self.contentType = ContentType.handshake 125 self.ssl2 = ssl2 126 self.client_version = (0,0) 127 self.random = createByteArrayZeros(32) 128 self.session_id = createByteArraySequence([]) 129 self.cipher_suites = [] # a list of 16-bit values 130 self.certificate_types = [CertificateType.x509] 131 self.compression_methods = [] # a list of 8-bit values 132 self.srp_username = None # a string 133 self.channel_id = False 134 self.support_signed_cert_timestamps = False 135 self.status_request = False 136 137 def create(self, version, random, session_id, cipher_suites, 138 certificate_types=None, srp_username=None): 139 self.client_version = version 140 self.random = random 141 self.session_id = session_id 142 self.cipher_suites = cipher_suites 143 self.certificate_types = certificate_types 144 self.compression_methods = [0] 145 self.srp_username = srp_username 146 return self 147 148 def parse(self, p): 149 if self.ssl2: 150 self.client_version = (p.get(1), p.get(1)) 151 cipherSpecsLength = p.get(2) 152 sessionIDLength = p.get(2) 153 randomLength = p.get(2) 154 self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3)) 155 self.session_id = p.getFixBytes(sessionIDLength) 156 self.random = p.getFixBytes(randomLength) 157 if len(self.random) < 32: 158 zeroBytes = 32-len(self.random) 159 self.random = createByteArrayZeros(zeroBytes) + self.random 160 self.compression_methods = [0]#Fake this value 161 162 #We're not doing a stopLengthCheck() for SSLv2, oh well.. 163 else: 164 p.startLengthCheck(3) 165 self.client_version = (p.get(1), p.get(1)) 166 self.random = p.getFixBytes(32) 167 self.session_id = p.getVarBytes(1) 168 self.cipher_suites = p.getVarList(2, 2) 169 self.compression_methods = p.getVarList(1, 1) 170 if not p.atLengthCheck(): 171 totalExtLength = p.get(2) 172 soFar = 0 173 while soFar != totalExtLength: 174 extType = p.get(2) 175 extLength = p.get(2) 176 if extType == 6: 177 self.srp_username = bytesToString(p.getVarBytes(1)) 178 elif extType == 7: 179 self.certificate_types = p.getVarList(1, 1) 180 elif extType == ExtensionType.channel_id: 181 self.channel_id = True 182 elif extType == ExtensionType.signed_cert_timestamps: 183 if extLength: 184 raise SyntaxError() 185 self.support_signed_cert_timestamps = True 186 elif extType == ExtensionType.status_request: 187 # Extension contents are currently ignored. 188 # According to RFC 6066, this is not strictly forbidden 189 # (although it is suboptimal): 190 # Servers that receive a client hello containing the 191 # "status_request" extension MAY return a suitable 192 # certificate status response to the client along with 193 # their certificate. If OCSP is requested, they 194 # SHOULD use the information contained in the extension 195 # when selecting an OCSP responder and SHOULD include 196 # request_extensions in the OCSP request. 197 p.getFixBytes(extLength) 198 self.status_request = True 199 else: 200 p.getFixBytes(extLength) 201 soFar += 4 + extLength 202 p.stopLengthCheck() 203 return self 204 205 def write(self, trial=False): 206 w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial) 207 w.add(self.client_version[0], 1) 208 w.add(self.client_version[1], 1) 209 w.addFixSeq(self.random, 1) 210 w.addVarSeq(self.session_id, 1, 1) 211 w.addVarSeq(self.cipher_suites, 2, 2) 212 w.addVarSeq(self.compression_methods, 1, 1) 213 214 extLength = 0 215 if self.certificate_types and self.certificate_types != \ 216 [CertificateType.x509]: 217 extLength += 5 + len(self.certificate_types) 218 if self.srp_username: 219 extLength += 5 + len(self.srp_username) 220 if extLength > 0: 221 w.add(extLength, 2) 222 223 if self.certificate_types and self.certificate_types != \ 224 [CertificateType.x509]: 225 w.add(7, 2) 226 w.add(len(self.certificate_types)+1, 2) 227 w.addVarSeq(self.certificate_types, 1, 1) 228 if self.srp_username: 229 w.add(6, 2) 230 w.add(len(self.srp_username)+1, 2) 231 w.addVarSeq(stringToBytes(self.srp_username), 1, 1) 232 233 return HandshakeMsg.postWrite(self, w, trial) 234 235 236class ServerHello(HandshakeMsg): 237 def __init__(self): 238 self.contentType = ContentType.handshake 239 self.server_version = (0,0) 240 self.random = createByteArrayZeros(32) 241 self.session_id = createByteArraySequence([]) 242 self.cipher_suite = 0 243 self.certificate_type = CertificateType.x509 244 self.compression_method = 0 245 self.channel_id = False 246 self.signed_cert_timestamps = None 247 self.status_request = False 248 249 def create(self, version, random, session_id, cipher_suite, 250 certificate_type): 251 self.server_version = version 252 self.random = random 253 self.session_id = session_id 254 self.cipher_suite = cipher_suite 255 self.certificate_type = certificate_type 256 self.compression_method = 0 257 return self 258 259 def parse(self, p): 260 p.startLengthCheck(3) 261 self.server_version = (p.get(1), p.get(1)) 262 self.random = p.getFixBytes(32) 263 self.session_id = p.getVarBytes(1) 264 self.cipher_suite = p.get(2) 265 self.compression_method = p.get(1) 266 if not p.atLengthCheck(): 267 totalExtLength = p.get(2) 268 soFar = 0 269 while soFar != totalExtLength: 270 extType = p.get(2) 271 extLength = p.get(2) 272 if extType == 7: 273 self.certificate_type = p.get(1) 274 else: 275 p.getFixBytes(extLength) 276 soFar += 4 + extLength 277 p.stopLengthCheck() 278 return self 279 280 def write(self, trial=False): 281 w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial) 282 w.add(self.server_version[0], 1) 283 w.add(self.server_version[1], 1) 284 w.addFixSeq(self.random, 1) 285 w.addVarSeq(self.session_id, 1, 1) 286 w.add(self.cipher_suite, 2) 287 w.add(self.compression_method, 1) 288 289 extLength = 0 290 if self.certificate_type and self.certificate_type != \ 291 CertificateType.x509: 292 extLength += 5 293 294 if self.channel_id: 295 extLength += 4 296 297 if self.signed_cert_timestamps: 298 extLength += 4 + len(self.signed_cert_timestamps) 299 300 if self.status_request: 301 extLength += 4 302 303 if extLength != 0: 304 w.add(extLength, 2) 305 306 if self.certificate_type and self.certificate_type != \ 307 CertificateType.x509: 308 w.add(7, 2) 309 w.add(1, 2) 310 w.add(self.certificate_type, 1) 311 312 if self.channel_id: 313 w.add(ExtensionType.channel_id, 2) 314 w.add(0, 2) 315 316 if self.signed_cert_timestamps: 317 w.add(ExtensionType.signed_cert_timestamps, 2) 318 w.addVarSeq(stringToBytes(self.signed_cert_timestamps), 1, 2) 319 320 if self.status_request: 321 w.add(ExtensionType.status_request, 2) 322 w.add(0, 2) 323 324 return HandshakeMsg.postWrite(self, w, trial) 325 326class Certificate(HandshakeMsg): 327 def __init__(self, certificateType): 328 self.certificateType = certificateType 329 self.contentType = ContentType.handshake 330 self.certChain = None 331 332 def create(self, certChain): 333 self.certChain = certChain 334 return self 335 336 def parse(self, p): 337 p.startLengthCheck(3) 338 if self.certificateType == CertificateType.x509: 339 chainLength = p.get(3) 340 index = 0 341 certificate_list = [] 342 while index != chainLength: 343 certBytes = p.getVarBytes(3) 344 x509 = X509() 345 x509.parseBinary(certBytes) 346 certificate_list.append(x509) 347 index += len(certBytes)+3 348 if certificate_list: 349 self.certChain = X509CertChain(certificate_list) 350 elif self.certificateType == CertificateType.cryptoID: 351 s = bytesToString(p.getVarBytes(2)) 352 if s: 353 try: 354 import cryptoIDlib.CertChain 355 except ImportError: 356 raise SyntaxError(\ 357 "cryptoID cert chain received, cryptoIDlib not present") 358 self.certChain = cryptoIDlib.CertChain.CertChain().parse(s) 359 else: 360 raise AssertionError() 361 362 p.stopLengthCheck() 363 return self 364 365 def write(self, trial=False): 366 w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial) 367 if self.certificateType == CertificateType.x509: 368 chainLength = 0 369 if self.certChain: 370 certificate_list = self.certChain.x509List 371 else: 372 certificate_list = [] 373 #determine length 374 for cert in certificate_list: 375 bytes = cert.writeBytes() 376 chainLength += len(bytes)+3 377 #add bytes 378 w.add(chainLength, 3) 379 for cert in certificate_list: 380 bytes = cert.writeBytes() 381 w.addVarSeq(bytes, 1, 3) 382 elif self.certificateType == CertificateType.cryptoID: 383 if self.certChain: 384 bytes = stringToBytes(self.certChain.write()) 385 else: 386 bytes = createByteArraySequence([]) 387 w.addVarSeq(bytes, 1, 2) 388 else: 389 raise AssertionError() 390 return HandshakeMsg.postWrite(self, w, trial) 391 392class CertificateStatus(HandshakeMsg): 393 def __init__(self): 394 self.contentType = ContentType.handshake 395 396 def create(self, ocsp_response): 397 self.ocsp_response = ocsp_response 398 return self 399 400 # Defined for the sake of completeness, even though we currently only 401 # support sending the status message (server-side), not requesting 402 # or receiving it (client-side). 403 def parse(self, p): 404 p.startLengthCheck(3) 405 status_type = p.get(1) 406 # Only one type is specified, so hardwire it. 407 if status_type != CertificateStatusType.ocsp: 408 raise SyntaxError() 409 ocsp_response = p.getVarBytes(3) 410 if not ocsp_response: 411 # Can't be empty 412 raise SyntaxError() 413 self.ocsp_response = ocsp_response 414 return self 415 416 def write(self, trial=False): 417 w = HandshakeMsg.preWrite(self, HandshakeType.certificate_status, 418 trial) 419 w.add(CertificateStatusType.ocsp, 1) 420 w.addVarSeq(stringToBytes(self.ocsp_response), 1, 3) 421 return HandshakeMsg.postWrite(self, w, trial) 422 423class CertificateRequest(HandshakeMsg): 424 def __init__(self): 425 self.contentType = ContentType.handshake 426 #Apple's Secure Transport library rejects empty certificate_types, so 427 #default to rsa_sign. 428 self.certificate_types = [ClientCertificateType.rsa_sign] 429 self.certificate_authorities = [] 430 431 def create(self, certificate_types, certificate_authorities): 432 self.certificate_types = certificate_types 433 self.certificate_authorities = certificate_authorities 434 return self 435 436 def parse(self, p): 437 p.startLengthCheck(3) 438 self.certificate_types = p.getVarList(1, 1) 439 ca_list_length = p.get(2) 440 index = 0 441 self.certificate_authorities = [] 442 while index != ca_list_length: 443 ca_bytes = p.getVarBytes(2) 444 self.certificate_authorities.append(ca_bytes) 445 index += len(ca_bytes)+2 446 p.stopLengthCheck() 447 return self 448 449 def write(self, trial=False): 450 w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request, 451 trial) 452 w.addVarSeq(self.certificate_types, 1, 1) 453 caLength = 0 454 #determine length 455 for ca_dn in self.certificate_authorities: 456 caLength += len(ca_dn)+2 457 w.add(caLength, 2) 458 #add bytes 459 for ca_dn in self.certificate_authorities: 460 w.addVarSeq(ca_dn, 1, 2) 461 return HandshakeMsg.postWrite(self, w, trial) 462 463class ServerKeyExchange(HandshakeMsg): 464 def __init__(self, cipherSuite): 465 self.cipherSuite = cipherSuite 466 self.contentType = ContentType.handshake 467 self.srp_N = 0L 468 self.srp_g = 0L 469 self.srp_s = createByteArraySequence([]) 470 self.srp_B = 0L 471 self.signature = createByteArraySequence([]) 472 473 def createSRP(self, srp_N, srp_g, srp_s, srp_B): 474 self.srp_N = srp_N 475 self.srp_g = srp_g 476 self.srp_s = srp_s 477 self.srp_B = srp_B 478 return self 479 480 def parse(self, p): 481 p.startLengthCheck(3) 482 self.srp_N = bytesToNumber(p.getVarBytes(2)) 483 self.srp_g = bytesToNumber(p.getVarBytes(2)) 484 self.srp_s = p.getVarBytes(1) 485 self.srp_B = bytesToNumber(p.getVarBytes(2)) 486 if self.cipherSuite in CipherSuite.srpRsaSuites: 487 self.signature = p.getVarBytes(2) 488 p.stopLengthCheck() 489 return self 490 491 def write(self, trial=False): 492 w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange, 493 trial) 494 w.addVarSeq(numberToBytes(self.srp_N), 1, 2) 495 w.addVarSeq(numberToBytes(self.srp_g), 1, 2) 496 w.addVarSeq(self.srp_s, 1, 1) 497 w.addVarSeq(numberToBytes(self.srp_B), 1, 2) 498 if self.cipherSuite in CipherSuite.srpRsaSuites: 499 w.addVarSeq(self.signature, 1, 2) 500 return HandshakeMsg.postWrite(self, w, trial) 501 502 def hash(self, clientRandom, serverRandom): 503 oldCipherSuite = self.cipherSuite 504 self.cipherSuite = None 505 try: 506 bytes = clientRandom + serverRandom + self.write()[4:] 507 s = bytesToString(bytes) 508 return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest()) 509 finally: 510 self.cipherSuite = oldCipherSuite 511 512class ServerHelloDone(HandshakeMsg): 513 def __init__(self): 514 self.contentType = ContentType.handshake 515 516 def create(self): 517 return self 518 519 def parse(self, p): 520 p.startLengthCheck(3) 521 p.stopLengthCheck() 522 return self 523 524 def write(self, trial=False): 525 w = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done, trial) 526 return HandshakeMsg.postWrite(self, w, trial) 527 528class ClientKeyExchange(HandshakeMsg): 529 def __init__(self, cipherSuite, version=None): 530 self.cipherSuite = cipherSuite 531 self.version = version 532 self.contentType = ContentType.handshake 533 self.srp_A = 0 534 self.encryptedPreMasterSecret = createByteArraySequence([]) 535 536 def createSRP(self, srp_A): 537 self.srp_A = srp_A 538 return self 539 540 def createRSA(self, encryptedPreMasterSecret): 541 self.encryptedPreMasterSecret = encryptedPreMasterSecret 542 return self 543 544 def parse(self, p): 545 p.startLengthCheck(3) 546 if self.cipherSuite in CipherSuite.srpSuites + \ 547 CipherSuite.srpRsaSuites: 548 self.srp_A = bytesToNumber(p.getVarBytes(2)) 549 elif self.cipherSuite in CipherSuite.rsaSuites: 550 if self.version in ((3,1), (3,2)): 551 self.encryptedPreMasterSecret = p.getVarBytes(2) 552 elif self.version == (3,0): 553 self.encryptedPreMasterSecret = \ 554 p.getFixBytes(len(p.bytes)-p.index) 555 else: 556 raise AssertionError() 557 else: 558 raise AssertionError() 559 p.stopLengthCheck() 560 return self 561 562 def write(self, trial=False): 563 w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange, 564 trial) 565 if self.cipherSuite in CipherSuite.srpSuites + \ 566 CipherSuite.srpRsaSuites: 567 w.addVarSeq(numberToBytes(self.srp_A), 1, 2) 568 elif self.cipherSuite in CipherSuite.rsaSuites: 569 if self.version in ((3,1), (3,2)): 570 w.addVarSeq(self.encryptedPreMasterSecret, 1, 2) 571 elif self.version == (3,0): 572 w.addFixSeq(self.encryptedPreMasterSecret, 1) 573 else: 574 raise AssertionError() 575 else: 576 raise AssertionError() 577 return HandshakeMsg.postWrite(self, w, trial) 578 579class CertificateVerify(HandshakeMsg): 580 def __init__(self): 581 self.contentType = ContentType.handshake 582 self.signature = createByteArraySequence([]) 583 584 def create(self, signature): 585 self.signature = signature 586 return self 587 588 def parse(self, p): 589 p.startLengthCheck(3) 590 self.signature = p.getVarBytes(2) 591 p.stopLengthCheck() 592 return self 593 594 def write(self, trial=False): 595 w = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify, 596 trial) 597 w.addVarSeq(self.signature, 1, 2) 598 return HandshakeMsg.postWrite(self, w, trial) 599 600class ChangeCipherSpec(Msg): 601 def __init__(self): 602 self.contentType = ContentType.change_cipher_spec 603 self.type = 1 604 605 def create(self): 606 self.type = 1 607 return self 608 609 def parse(self, p): 610 p.setLengthCheck(1) 611 self.type = p.get(1) 612 p.stopLengthCheck() 613 return self 614 615 def write(self, trial=False): 616 w = Msg.preWrite(self, trial) 617 w.add(self.type,1) 618 return Msg.postWrite(self, w, trial) 619 620 621class Finished(HandshakeMsg): 622 def __init__(self, version): 623 self.contentType = ContentType.handshake 624 self.version = version 625 self.verify_data = createByteArraySequence([]) 626 627 def create(self, verify_data): 628 self.verify_data = verify_data 629 return self 630 631 def parse(self, p): 632 p.startLengthCheck(3) 633 if self.version == (3,0): 634 self.verify_data = p.getFixBytes(36) 635 elif self.version in ((3,1), (3,2)): 636 self.verify_data = p.getFixBytes(12) 637 else: 638 raise AssertionError() 639 p.stopLengthCheck() 640 return self 641 642 def write(self, trial=False): 643 w = HandshakeMsg.preWrite(self, HandshakeType.finished, trial) 644 w.addFixSeq(self.verify_data, 1) 645 return HandshakeMsg.postWrite(self, w, trial) 646 647class EncryptedExtensions(HandshakeMsg): 648 def __init__(self): 649 self.channel_id_key = None 650 self.channel_id_proof = None 651 652 def parse(self, p): 653 p.startLengthCheck(3) 654 soFar = 0 655 while soFar != p.lengthCheck: 656 extType = p.get(2) 657 extLength = p.get(2) 658 if extType == ExtensionType.channel_id: 659 if extLength != 32*4: 660 raise SyntaxError() 661 self.channel_id_key = p.getFixBytes(64) 662 self.channel_id_proof = p.getFixBytes(64) 663 else: 664 p.getFixBytes(extLength) 665 soFar += 4 + extLength 666 p.stopLengthCheck() 667 return self 668 669class ApplicationData(Msg): 670 def __init__(self): 671 self.contentType = ContentType.application_data 672 self.bytes = createByteArraySequence([]) 673 674 def create(self, bytes): 675 self.bytes = bytes 676 return self 677 678 def parse(self, p): 679 self.bytes = p.bytes 680 return self 681 682 def write(self): 683 return self.bytes 684