1/*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.io.File;
20import java.io.FileWriter;
21import java.security.cert.Certificate;
22import java.security.cert.CertificateException;
23import java.security.cert.X509Certificate;
24import java.security.KeyStore;
25import java.security.MessageDigest;
26import java.security.Principal;
27import java.util.ArrayList;
28import java.util.Arrays;
29import java.util.List;
30import javax.net.ssl.SSLPeerUnverifiedException;
31import javax.net.ssl.SSLSession;
32import javax.net.ssl.SSLSessionContext;
33import javax.net.ssl.TrustManagerFactory;
34import javax.net.ssl.X509TrustManager;
35import junit.framework.TestCase;
36import libcore.java.security.TestKeyStore;
37
38public class TrustManagerImplTest extends TestCase {
39
40    private List<File> tmpFiles = new ArrayList<File>();
41
42    private String getFingerprint(X509Certificate cert) throws Exception {
43        MessageDigest dgst = MessageDigest.getInstance("SHA512");
44        byte[] encoded = cert.getPublicKey().getEncoded();
45        byte[] fingerprint = dgst.digest(encoded);
46        return IntegralToString.bytesToHexString(fingerprint, false);
47    }
48
49    private String writeTmpPinFile(String text) throws Exception {
50        File tmp = File.createTempFile("pins", null);
51        FileWriter fstream = new FileWriter(tmp);
52        fstream.write(text);
53        fstream.close();
54        tmpFiles.add(tmp);
55        return tmp.getPath();
56    }
57
58    @Override
59    public void tearDown() throws Exception {
60        try {
61            for (File f : tmpFiles) {
62                f.delete();
63            }
64            tmpFiles.clear();
65        } finally {
66            super.tearDown();
67        }
68    }
69
70    /**
71     * Ensure that our non-standard behavior of learning to trust new
72     * intermediate CAs does not regress. http://b/3404902
73     */
74    public void testLearnIntermediate() throws Exception {
75        // chain3 should be server/intermediate/root
76        KeyStore.PrivateKeyEntry pke = TestKeyStore.getServer().getPrivateKey("RSA", "RSA");
77        X509Certificate[] chain3 = (X509Certificate[])pke.getCertificateChain();
78        X509Certificate root = chain3[2];
79        X509Certificate intermediate = chain3[1];
80        X509Certificate server = chain3[0];
81        X509Certificate[] chain2 =  new X509Certificate[] { server, intermediate };
82        X509Certificate[] chain1 =  new X509Certificate[] { server };
83
84        // Normal behavior
85        assertValid(chain3,   trustManager(root));
86        assertValid(chain2,   trustManager(root));
87        assertInvalid(chain1, trustManager(root));
88        assertValid(chain3,   trustManager(intermediate));
89        assertValid(chain2,   trustManager(intermediate));
90        assertValid(chain1,   trustManager(intermediate));
91        assertValid(chain3,   trustManager(server));
92        assertValid(chain2,   trustManager(server));
93        assertValid(chain1,   trustManager(server));
94
95        // non-standard behavior
96        X509TrustManager tm = trustManager(root);
97        // fail on short chain with only root trusted
98        assertInvalid(chain1, tm);
99        // succeed on longer chain, learn intermediate
100        assertValid(chain2, tm);
101        // now we can validate the short chain
102        assertValid(chain1, tm);
103    }
104
105    // We should ignore duplicate cruft in the certificate chain
106    // See https://code.google.com/p/android/issues/detail?id=52295 http://b/8313312
107    public void testDuplicateInChain() throws Exception {
108        // chain3 should be server/intermediate/root
109        KeyStore.PrivateKeyEntry pke = TestKeyStore.getServer().getPrivateKey("RSA", "RSA");
110        X509Certificate[] chain3 = (X509Certificate[])pke.getCertificateChain();
111        X509Certificate root = chain3[2];
112        X509Certificate intermediate = chain3[1];
113        X509Certificate server = chain3[0];
114
115        X509Certificate[] chain4 = new X509Certificate[] { server, intermediate,
116                                                           server, intermediate
117        };
118        assertValid(chain4, trustManager(root));
119    }
120
121    public void testGetFullChain() throws Exception {
122        // build the trust manager
123        KeyStore.PrivateKeyEntry pke = TestKeyStore.getServer().getPrivateKey("RSA", "RSA");
124        X509Certificate[] chain3 = (X509Certificate[])pke.getCertificateChain();
125        X509Certificate root = chain3[2];
126        X509TrustManager tm = trustManager(root);
127
128        // build the chains we'll use for testing
129        X509Certificate intermediate = chain3[1];
130        X509Certificate server = chain3[0];
131        X509Certificate[] chain2 =  new X509Certificate[] { server, intermediate };
132        X509Certificate[] chain1 =  new X509Certificate[] { server };
133
134        assertTrue(tm instanceof TrustManagerImpl);
135        TrustManagerImpl tmi = (TrustManagerImpl) tm;
136        List<X509Certificate> certs = tmi.checkServerTrusted(chain2, "RSA", new MySSLSession(
137                "purple.com"));
138        assertEquals(Arrays.asList(chain3), certs);
139        certs = tmi.checkServerTrusted(chain1, "RSA", new MySSLSession("purple.com"));
140        assertEquals(Arrays.asList(chain3), certs);
141    }
142
143    public void testCertPinning() throws Exception {
144        // chain3 should be server/intermediate/root
145        KeyStore.PrivateKeyEntry pke = TestKeyStore.getServer().getPrivateKey("RSA", "RSA");
146        X509Certificate[] chain3 = (X509Certificate[]) pke.getCertificateChain();
147        X509Certificate root = chain3[2];
148        X509Certificate intermediate = chain3[1];
149        X509Certificate server = chain3[0];
150        X509Certificate[] chain2 =  new X509Certificate[] { server, intermediate };
151        X509Certificate[] chain1 =  new X509Certificate[] { server };
152
153        // test without a hostname, expecting failure
154        assertInvalidPinned(chain1, trustManager(root, "gugle.com", root), null);
155        // test without a hostname, expecting success
156        assertValidPinned(chain3, trustManager(root, "gugle.com", root), null, chain3);
157        // test an unpinned hostname that should fail
158        assertInvalidPinned(chain1, trustManager(root, "gugle.com", root), "purple.com");
159        // test an unpinned hostname that should succeed
160        assertValidPinned(chain3, trustManager(root, "gugle.com", root), "purple.com", chain3);
161        // test a pinned hostname that should fail
162        assertInvalidPinned(chain1, trustManager(intermediate, "gugle.com", root), "gugle.com");
163        // test a pinned hostname that should succeed
164        assertValidPinned(chain2, trustManager(intermediate, "gugle.com", server), "gugle.com",
165                          chain2);
166        // test a pinned hostname that chains to user installed that should succeed
167        assertValidPinned(chain2, trustManagerUserInstalled(
168            (X509Certificate)TestKeyStore.getIntermediateCa2().getPrivateKey("RSA", "RSA")
169                .getCertificateChain()[1], intermediate, "gugle.com", server), "gugle.com",
170                chain2, true);
171    }
172
173    private X509TrustManager trustManager(X509Certificate ca) throws Exception {
174        KeyStore keyStore = TestKeyStore.createKeyStore();
175        keyStore.setCertificateEntry("alias", ca);
176
177        String algorithm = TrustManagerFactory.getDefaultAlgorithm();
178        TrustManagerFactory tmf = TrustManagerFactory.getInstance(algorithm);
179        tmf.init(keyStore);
180        return (X509TrustManager) tmf.getTrustManagers()[0];
181    }
182
183    private TrustManagerImpl trustManager(X509Certificate ca, String hostname, X509Certificate pin)
184                                          throws Exception {
185        // build the cert pin manager
186        CertPinManager cm = certManager(hostname, pin);
187        // insert it into the trust manager
188        KeyStore keyStore = TestKeyStore.createKeyStore();
189        keyStore.setCertificateEntry("alias", ca);
190        return new TrustManagerImpl(keyStore, cm);
191    }
192
193    private TrustManagerImpl trustManagerUserInstalled(
194        X509Certificate caKeyStore, X509Certificate caUserStore, String hostname,
195        X509Certificate pin) throws Exception {
196        // build the cert pin manager
197        CertPinManager cm = certManager(hostname, pin);
198
199        // install at least one cert in the store (requirement)
200        KeyStore keyStore = TestKeyStore.createKeyStore();
201        keyStore.setCertificateEntry("alias", caKeyStore);
202
203        // install a cert into the user installed store
204        final File DIR_TEMP = new File(System.getProperty("java.io.tmpdir"));
205        final File DIR_TEST = new File(DIR_TEMP, "test");
206        final File system = new File(DIR_TEST, "system-test");
207        final File added = new File(DIR_TEST, "added-test");
208        final File deleted = new File(DIR_TEST, "deleted-test");
209
210        TrustedCertificateStore tcs = new TrustedCertificateStore(system, added, deleted);
211        added.mkdirs();
212        tcs.installCertificate(caUserStore);
213        return new TrustManagerImpl(keyStore, cm, tcs);
214    }
215
216    private CertPinManager certManager(String hostname, X509Certificate pin) throws Exception {
217        String pinString = "";
218        if (pin != null) {
219            pinString = hostname + "=true|" + getFingerprint(pin);
220        }
221        // write it to a pinfile
222        String path = writeTmpPinFile(pinString);
223        // build the certpinmanager
224        return new CertPinManager(path, new TrustedCertificateStore());
225    }
226
227    private void assertValid(X509Certificate[] chain, X509TrustManager tm) throws Exception {
228        if (tm instanceof TrustManagerImpl) {
229            TrustManagerImpl tmi = (TrustManagerImpl) tm;
230            tmi.checkServerTrusted(chain, "RSA");
231        }
232        tm.checkServerTrusted(chain, "RSA");
233    }
234
235    private void assertValidPinned(X509Certificate[] chain, X509TrustManager tm, String hostname,
236                                   X509Certificate[] fullChain) throws Exception {
237        assertValidPinned(chain, tm, hostname, fullChain, false);
238    }
239
240    private void assertValidPinned(X509Certificate[] chain, X509TrustManager tm, String hostname,
241                                   X509Certificate[] fullChain, boolean expectUserInstalled)
242                                   throws Exception {
243        if (tm instanceof TrustManagerImpl) {
244            TrustManagerImpl tmi = (TrustManagerImpl) tm;
245            List<X509Certificate> checkedChain = tmi.checkServerTrusted(chain, "RSA",
246                    new MySSLSession(hostname));
247            assertEquals(checkedChain, Arrays.asList(fullChain));
248            boolean chainContainsUserInstalled = false;
249            for (X509Certificate cert : checkedChain) {
250                if (tmi.isUserAddedCertificate(cert)) {
251                    chainContainsUserInstalled = true;
252                }
253            }
254            assertEquals(expectUserInstalled, chainContainsUserInstalled);
255        }
256        tm.checkServerTrusted(chain, "RSA");
257    }
258
259    private void assertInvalid(X509Certificate[] chain, X509TrustManager tm) {
260        try {
261            tm.checkClientTrusted(chain, "RSA");
262            fail();
263        } catch (CertificateException expected) {
264        }
265        try {
266            tm.checkServerTrusted(chain, "RSA");
267            fail();
268        } catch (CertificateException expected) {
269        }
270    }
271
272    private void assertInvalidPinned(X509Certificate[] chain, X509TrustManager tm, String hostname)
273                                     throws Exception {
274        assertTrue(tm.getClass().getName(), tm instanceof TrustManagerImpl);
275        try {
276            TrustManagerImpl tmi = (TrustManagerImpl) tm;
277            tmi.checkServerTrusted(chain, "RSA", new MySSLSession(hostname));
278            fail();
279        } catch (CertificateException expected) {
280        }
281    }
282
283    private class MySSLSession implements SSLSession {
284        private final String hostname;
285
286        public MySSLSession(String hostname) {
287            this.hostname = hostname;
288        }
289
290        @Override
291        public int getApplicationBufferSize() {
292            throw new UnsupportedOperationException();
293        }
294
295        @Override
296        public String getCipherSuite() {
297            throw new UnsupportedOperationException();
298        }
299
300        @Override
301        public long getCreationTime() {
302            throw new UnsupportedOperationException();
303        }
304
305        @Override
306        public byte[] getId() {
307            throw new UnsupportedOperationException();
308        }
309
310        @Override
311        public long getLastAccessedTime() {
312            throw new UnsupportedOperationException();
313        }
314
315        @Override
316        public Certificate[] getLocalCertificates() {
317            throw new UnsupportedOperationException();
318        }
319
320        @Override
321        public Principal getLocalPrincipal() {
322            throw new UnsupportedOperationException();
323        }
324
325        @Override
326        public int getPacketBufferSize() {
327            throw new UnsupportedOperationException();
328        }
329
330        @Override
331        public javax.security.cert.X509Certificate[] getPeerCertificateChain()
332                throws SSLPeerUnverifiedException {
333            throw new UnsupportedOperationException();
334        }
335
336        @Override
337        public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException {
338            throw new UnsupportedOperationException();
339        }
340
341        @Override
342        public String getPeerHost() {
343            return hostname;
344        }
345
346        @Override
347        public int getPeerPort() {
348            throw new UnsupportedOperationException();
349        }
350
351        @Override
352        public Principal getPeerPrincipal() throws SSLPeerUnverifiedException {
353            throw new UnsupportedOperationException();
354        }
355
356        @Override
357        public String getProtocol() {
358            throw new UnsupportedOperationException();
359        }
360
361        @Override
362        public SSLSessionContext getSessionContext() {
363            throw new UnsupportedOperationException();
364        }
365
366        @Override
367        public Object getValue(String name) {
368            throw new UnsupportedOperationException();
369        }
370
371        @Override
372        public String[] getValueNames() {
373            throw new UnsupportedOperationException();
374        }
375
376        @Override
377        public void invalidate() {
378            throw new UnsupportedOperationException();
379        }
380
381        @Override
382        public boolean isValid() {
383            throw new UnsupportedOperationException();
384        }
385
386        @Override
387        public void putValue(String name, Object value) {
388            throw new UnsupportedOperationException();
389        }
390
391        @Override
392        public void removeValue(String name) {
393            throw new UnsupportedOperationException();
394        }
395    }
396}
397