1"""Class representing an X.509 certificate chain."""
2
3from utils import cryptomath
4from X509 import X509
5
class X509CertChain:
    """This class represents a chain of X.509 certificates.

    @type x509List: list
    @ivar x509List: A list of L{tlslite.X509.X509} instances,
    starting with the end-entity certificate and with every
    subsequent certificate certifying the previous.
    """

    def __init__(self, x509List=None):
        """Create a new X509CertChain.

        @type x509List: list
        @param x509List: A list of L{tlslite.X509.X509} instances,
        starting with the end-entity certificate and with every
        subsequent certificate certifying the previous.  A non-empty
        list is stored by reference (not copied); if omitted or empty,
        the chain starts out as a fresh empty list.
        """
        if x509List:
            self.x509List = x509List
        else:
            self.x509List = []

    def parseChain(self, s):
        """Parse a PEM-encoded X.509 certificate chain, replacing this chain.

        @type s: str
        @param s: A PEM-encoded (i.e. Base64) X.509 certificate chain, with
        every certificate wrapped within "-----BEGIN CERTIFICATE-----" and
        "-----END CERTIFICATE-----" tags.  Extraneous data outside such tags,
        such as human readable representations, will be ignored.  Parsing
        stops at a header that has no matching footer.

        @rtype: L{X509CertChain}
        @return: self, so calls can be chained.
        """

        class PEMIterator(object):
            """Simple iterator over PEM-encoded certificates within a string.

            @type data: string
            @ivar data: A string containing PEM-encoded (Base64) certificates,
            with every certificate wrapped within "-----BEGIN CERTIFICATE-----"
            and "-----END CERTIFICATE-----" tags.  Extraneous data outside such
            tags, such as human readable representations, will be ignored.

            @type index: integer
            @ivar index: The current offset within data to begin iterating from.
            """

            # The PEM encoding block header for X.509 certificates.
            _CERTIFICATE_HEADER = "-----BEGIN CERTIFICATE-----"

            # The PEM encoding block footer for X.509 certificates.
            _CERTIFICATE_FOOTER = "-----END CERTIFICATE-----"

            def __init__(self, s):
                self.data = s
                self.index = 0

            def __iter__(self):
                return self

            def next(self):
                """Return the next L{tlslite.X509.X509} certificate in data.

                @rtype: L{tlslite.X509.X509}
                @raise StopIteration: When no further complete
                header/footer-delimited block remains.
                """
                self.index = self.data.find(self._CERTIFICATE_HEADER,
                                            self.index)
                if self.index == -1:
                    raise StopIteration
                end = self.data.find(self._CERTIFICATE_FOOTER, self.index)
                if end == -1:
                    raise StopIteration

                certStr = self.data[self.index + len(self._CERTIFICATE_HEADER):
                                    end]
                self.index = end + len(self._CERTIFICATE_FOOTER)
                certBytes = cryptomath.base64ToBytes(certStr)
                return X509().parseBinary(certBytes)

            # Python 3 iterator protocol name; Python 2 uses next().
            __next__ = next

        self.x509List = list(PEMIterator(s))
        return self

    def getNumCerts(self):
        """Get the number of certificates in this chain.

        @rtype: int
        """
        return len(self.x509List)

    def _getEndEntity(self):
        """Return the first (end-entity) certificate.

        @raise AssertionError: If the chain is empty.
        """
        if self.getNumCerts() == 0:
            raise AssertionError()
        return self.x509List[0]

    def getEndEntityPublicKey(self):
        """Get the public key from the end-entity certificate.

        @rtype: L{tlslite.utils.RSAKey.RSAKey}
        @raise AssertionError: If the chain is empty.
        """
        return self._getEndEntity().publicKey

    def getFingerprint(self):
        """Get the hex-encoded fingerprint of the end-entity certificate.

        @rtype: str
        @return: A hex-encoded fingerprint.
        @raise AssertionError: If the chain is empty.
        """
        return self._getEndEntity().getFingerprint()

    def getCommonName(self):
        """Get the Subject's Common Name from the end-entity certificate.

        The cryptlib_py module must be installed in order to use this
        function.

        @rtype: str or None
        @return: The CN component of the certificate's subject DN, if
        present.
        @raise AssertionError: If the chain is empty.
        """
        return self._getEndEntity().getCommonName()

    def validate(self, x509TrustList):
        """Check the validity of the certificate chain.

        This checks that every certificate in the chain validates with
        the subsequent one, until some certificate validates with (or
        is identical to) one of the passed-in root certificates.

        The cryptlib_py module must be installed in order to use this
        function.

        @type x509TrustList: list of L{tlslite.X509.X509}
        @param x509TrustList: A list of trusted root certificates.  The
        certificate chain must extend to one of these certificates to
        be considered valid.

        @rtype: bool
        @return: True if the chain extends to a trusted root, False
        otherwise (including when the chain is empty).
        """
        import cryptlib_py

        # An empty chain cannot extend to any trusted root.
        if self.getNumCerts() == 0:
            return False

        # Fingerprints are tested once per certificate; a set keeps that O(1).
        rootFingerprints = set([c.getFingerprint() for c in x509TrustList])

        # Check that every certificate in the chain validates with the next
        # one.  Reaching a trusted root ends the walk successfully.
        for cert, issuer in zip(self.x509List, self.x509List[1:]):
            if cert.getFingerprint() in rootFingerprints:
                return True
            if not self._checkPair(cert, issuer):
                return False

        lastCert = self.x509List[-1]
        if lastCert.getFingerprint() in rootFingerprints:
            return True

        # Otherwise, look for a trusted root whose name matches the last
        # certificate's issuer and which verifies it.  A name match whose
        # check fails does not end the search: several trusted roots may
        # share the same name.
        lastC = cryptlib_py.cryptImportCert(lastCert.writeBytes(),
                                            cryptlib_py.CRYPT_UNUSED)
        try:
            for rootCert in x509TrustList:
                rootC = cryptlib_py.cryptImportCert(rootCert.writeBytes(),
                                                    cryptlib_py.CRYPT_UNUSED)
                # Destroy each imported root handle, not only the last one.
                try:
                    if (self._checkChaining(lastC, rootC) and
                            self._checkSignature(lastC, rootC)):
                        return True
                finally:
                    cryptlib_py.cryptDestroyCert(rootC)
            return False
        finally:
            cryptlib_py.cryptDestroyCert(lastC)

    def _checkPair(self, cert, issuer):
        """Import two certificates into cryptlib and check cert against issuer.

        Both cryptlib handles are destroyed before returning, whatever
        the outcome.

        @rtype: bool
        """
        import cryptlib_py
        certC = cryptlib_py.cryptImportCert(cert.writeBytes(),
                                            cryptlib_py.CRYPT_UNUSED)
        try:
            issuerC = cryptlib_py.cryptImportCert(issuer.writeBytes(),
                                                  cryptlib_py.CRYPT_UNUSED)
            try:
                return self._checkSignature(certC, issuerC)
            finally:
                cryptlib_py.cryptDestroyCert(issuerC)
        finally:
            cryptlib_py.cryptDestroyCert(certC)

    def _checkSignature(self, certC, issuerC):
        """Return whether cryptCheckCert(certC, issuerC) succeeds.

        A cryptlib_py.CryptException is treated as a failed check; any
        other exception propagates to the caller.

        @rtype: bool
        """
        import cryptlib_py
        try:
            cryptlib_py.cryptCheckCert(certC, issuerC)
        except cryptlib_py.CryptException:
            return False
        return True

    def _getNameComponent(self, certC, attribute):
        """Read one DN component string from a cryptlib certificate handle.

        @return: The component as a string, or None if cryptlib reports
        CRYPT_ERROR_NOTFOUND for it.
        @raise cryptlib_py.CryptException: For any other cryptlib error.
        """
        import cryptlib_py
        import array
        try:
            # Querying with a None buffer yields the length to allocate.
            length = cryptlib_py.cryptGetAttributeString(certC, attribute,
                                                         None)
            buf = array.array('B', [0] * length)
            cryptlib_py.cryptGetAttributeString(certC, attribute, buf)
        except cryptlib_py.CryptException as e:
            if e.args[0] == cryptlib_py.CRYPT_ERROR_NOTFOUND:
                return None
            # Any other error is real; don't compare a half-read name.
            raise
        return buf.tostring()

    def _checkChaining(self, lastC, rootC):
        """Return whether lastC's issuer DN matches rootC's DN.

        Selects the issuer name on lastC, then compares the country,
        locality, organization, organizational unit and common name
        components against the same components read from rootC.  A
        component absent from both sides counts as a match.
        NOTE(review): rootC's DN is whichever name cryptlib selects by
        default (presumably the subject name) -- confirm.

        @rtype: bool
        """
        import cryptlib_py
        # Make subsequent DN component reads on lastC refer to its issuer.
        cryptlib_py.cryptSetAttribute(lastC,
                                      cryptlib_py.CRYPT_CERTINFO_ISSUERNAME,
                                      cryptlib_py.CRYPT_UNUSED)
        for attribute in (cryptlib_py.CRYPT_CERTINFO_COUNTRYNAME,
                          cryptlib_py.CRYPT_CERTINFO_LOCALITYNAME,
                          cryptlib_py.CRYPT_CERTINFO_ORGANIZATIONNAME,
                          cryptlib_py.CRYPT_CERTINFO_ORGANIZATIONALUNITNAME,
                          cryptlib_py.CRYPT_CERTINFO_COMMONNAME):
            if (self._getNameComponent(lastC, attribute) !=
                    self._getNameComponent(rootC, attribute)):
                return False
        return True