1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import static org.junit.Assert.assertEquals;
20import static org.junit.Assert.assertFalse;
21
22import java.io.FileNotFoundException;
23import java.io.IOException;
24import java.io.InputStream;
25import java.lang.reflect.Method;
26import java.net.InetAddress;
27import java.net.ServerSocket;
28import java.net.UnknownHostException;
29import java.nio.ByteBuffer;
30import java.nio.charset.Charset;
31import java.security.NoSuchAlgorithmException;
32import java.security.Provider;
33import java.security.Security;
34import java.util.ArrayList;
35import java.util.Arrays;
36import java.util.Iterator;
37import java.util.LinkedHashSet;
38import java.util.List;
39import java.util.Set;
40import javax.net.ssl.SSLContext;
41import javax.net.ssl.SSLEngine;
42import javax.net.ssl.SSLEngineResult;
43import javax.net.ssl.SSLException;
44import javax.net.ssl.SSLParameters;
45import javax.net.ssl.SSLServerSocketFactory;
46import javax.net.ssl.SSLSocketFactory;
47import libcore.io.Streams;
48import org.bouncycastle.jce.provider.BouncyCastleProvider;
49import org.conscrypt.java.security.TestKeyStore;
50import org.junit.Assume;
51
52/**
53 * Utility methods to support testing.
54 */
55public final class TestUtils {
56    public static final Charset UTF_8 = Charset.forName("UTF-8");
57    private static final String PROTOCOL_TLS_V1_2 = "TLSv1.2";
58    private static final String PROTOCOL_TLS_V1_1 = "TLSv1.1";
59    private static final String PROTOCOL_TLS_V1 = "TLSv1";
60    private static final String[] DESIRED_PROTOCOLS =
61        new String[] {PROTOCOL_TLS_V1_2, PROTOCOL_TLS_V1_1, /* For Java 6 */ PROTOCOL_TLS_V1};
62    private static final Provider JDK_PROVIDER = getDefaultTlsProvider();
63    private static final byte[] CHARS =
64            "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".getBytes(UTF_8);
65    private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0);
66    private static final String[] PROTOCOLS = getProtocolsInternal();
67
68    static final String TEST_CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
69
70    private TestUtils() {}
71
72    private static Provider getDefaultTlsProvider() {
73        for (String protocol : DESIRED_PROTOCOLS) {
74            for (Provider p : Security.getProviders()) {
75                if (hasProtocol(p, protocol)) {
76                    return p;
77                }
78            }
79        }
80        // For Java 1.6 testing
81        return new BouncyCastleProvider();
82    }
83
84    private static boolean hasProtocol(Provider p, String protocol) {
85        return p.get("SSLContext." + protocol) != null;
86    }
87
88    static Provider getJdkProvider() {
89        return JDK_PROVIDER;
90    }
91
92    private static void assumeClassAvailable(String classname) {
93        boolean available = false;
94        try {
95            Class.forName(classname);
96            available = true;
97        } catch (ClassNotFoundException ignore) {
98            // Ignored
99        }
100        Assume.assumeTrue("Skipping test: " + classname + " unavailable", available);
101    }
102
103    public static void assumeSNIHostnameAvailable() {
104        assumeClassAvailable("javax.net.ssl.SNIHostName");
105    }
106
107    public static void assumeSetEndpointIdentificationAlgorithmAvailable() {
108        boolean supported = false;
109        try {
110            SSLParameters.class.getMethod("setEndpointIdentificationAlgorithm", String.class);
111            supported = true;
112        } catch (NoSuchMethodException ignore) {
113            // Ignored
114        }
115        Assume.assumeTrue("Skipping test: "
116                + "SSLParameters.setEndpointIdentificationAlgorithm unavailable", supported);
117    }
118
119    public static void assumeAEADAvailable() {
120        assumeClassAvailable("javax.crypto.AEADBadTagException");
121    }
122
123    private static boolean isAndroid() {
124        try {
125            Class.forName("android.app.Application", false, ClassLoader.getSystemClassLoader());
126            return true;
127        } catch (Throwable ignored) {
128            // Failed to load the class uniquely available in Android.
129            return false;
130        }
131    }
132
133    public static void assumeAndroid() {
134        Assume.assumeTrue(isAndroid());
135    }
136
137    public static void assumeAllowsUnsignedCrypto() {
138        // The Oracle JRE disallows loading crypto providers from unsigned jars
139        Assume.assumeTrue(isAndroid()
140                || !System.getProperty("java.vm.name").contains("HotSpot"));
141    }
142
143    public static InetAddress getLoopbackAddress() {
144        try {
145            Method method = InetAddress.class.getMethod("getLoopbackAddress");
146            return (InetAddress) method.invoke(null);
147        } catch (Exception ignore) {
148            // Ignored.
149        }
150        try {
151            return InetAddress.getLocalHost();
152        } catch (UnknownHostException e) {
153            throw new RuntimeException(e);
154        }
155    }
156
157    public static Provider getConscryptProvider() {
158        try {
159            return (Provider) conscryptClass("OpenSSLProvider").getConstructor().newInstance();
160        } catch (Exception e) {
161            throw new RuntimeException(e);
162        }
163    }
164
165    public static synchronized void installConscryptAsDefaultProvider() {
166        final Provider conscryptProvider = getConscryptProvider();
167        Provider[] providers = Security.getProviders();
168        if (providers.length == 0 || !providers[0].equals(conscryptProvider)) {
169            Security.insertProviderAt(conscryptProvider, 1);
170        }
171    }
172
173    public static InputStream openTestFile(String name) throws FileNotFoundException {
174        InputStream is = TestUtils.class.getResourceAsStream("/" + name);
175        if (is == null) {
176            throw new FileNotFoundException(name);
177        }
178        return is;
179    }
180
181    public static byte[] readTestFile(String name) throws IOException {
182        return Streams.readFully(openTestFile(name));
183    }
184
185    /**
186     * Looks up the conscrypt class for the given simple name (i.e. no package prefix).
187     */
188    public static Class<?> conscryptClass(String simpleName) throws ClassNotFoundException {
189        ClassNotFoundException ex = null;
190        for (String packageName : new String[] {"org.conscrypt", "com.android.org.conscrypt"}) {
191            String name = packageName + "." + simpleName;
192            try {
193                return Class.forName(name);
194            } catch (ClassNotFoundException e) {
195                ex = e;
196            }
197        }
198        throw ex;
199    }
200
201    /**
202     * Returns an array containing only {@link #PROTOCOL_TLS_V1_2}.
203     */
204    public static String[] getProtocols() {
205        return PROTOCOLS;
206    }
207
208    private static String[] getProtocolsInternal() {
209        List<String> protocols = new ArrayList<String>();
210        for (String protocol : DESIRED_PROTOCOLS) {
211            if (hasProtocol(getJdkProvider(), protocol)) {
212                protocols.add(protocol);
213            }
214        }
215        return protocols.toArray(new String[protocols.size()]);
216    }
217
218    public static SSLSocketFactory getJdkSocketFactory() {
219        return getSocketFactory(JDK_PROVIDER);
220    }
221
222    public static SSLServerSocketFactory getJdkServerSocketFactory() {
223        return getServerSocketFactory(JDK_PROVIDER);
224    }
225
226    static SSLSocketFactory setUseEngineSocket(
227            SSLSocketFactory conscryptFactory, boolean useEngineSocket) {
228        try {
229            Class<?> clazz = conscryptClass("Conscrypt");
230            Method method =
231                    clazz.getMethod("setUseEngineSocket", SSLSocketFactory.class, boolean.class);
232            method.invoke(null, conscryptFactory, useEngineSocket);
233            return conscryptFactory;
234        } catch (Exception e) {
235            throw new RuntimeException(e);
236        }
237    }
238
239    static SSLServerSocketFactory setUseEngineSocket(
240            SSLServerSocketFactory conscryptFactory, boolean useEngineSocket) {
241        try {
242            Class<?> clazz = conscryptClass("Conscrypt");
243            Method method = clazz.getMethod(
244                    "setUseEngineSocket", SSLServerSocketFactory.class, boolean.class);
245            method.invoke(null, conscryptFactory, useEngineSocket);
246            return conscryptFactory;
247        } catch (Exception e) {
248            throw new RuntimeException(e);
249        }
250    }
251
252    public static SSLSocketFactory getConscryptSocketFactory(boolean useEngineSocket) {
253        return setUseEngineSocket(getSocketFactory(getConscryptProvider()), useEngineSocket);
254    }
255
256    public static SSLServerSocketFactory getConscryptServerSocketFactory(boolean useEngineSocket) {
257        return setUseEngineSocket(getServerSocketFactory(getConscryptProvider()), useEngineSocket);
258    }
259
260    private static SSLSocketFactory getSocketFactory(Provider provider) {
261        SSLContext clientContext = initClientSslContext(newContext(provider));
262        return clientContext.getSocketFactory();
263    }
264
265    private static SSLServerSocketFactory getServerSocketFactory(Provider provider) {
266        SSLContext serverContext = initServerSslContext(newContext(provider));
267        return serverContext.getServerSocketFactory();
268    }
269
270    static SSLContext newContext(Provider provider) {
271        try {
272            return SSLContext.getInstance("TLS", provider);
273        } catch (NoSuchAlgorithmException e) {
274            throw new RuntimeException(e);
275        }
276    }
277
278    static String[] getCommonCipherSuites() {
279        SSLContext jdkContext =
280                TestUtils.initSslContext(newContext(getJdkProvider()), TestKeyStore.getClient());
281        SSLContext conscryptContext = TestUtils.initSslContext(
282                newContext(getConscryptProvider()), TestKeyStore.getClient());
283        Set<String> supported = new LinkedHashSet<String>();
284        supported.addAll(supportedCiphers(jdkContext));
285        supported.retainAll(supportedCiphers(conscryptContext));
286        filterCiphers(supported);
287
288        return supported.toArray(new String[supported.size()]);
289    }
290
291    private static List<String> supportedCiphers(SSLContext ctx) {
292        return Arrays.asList(ctx.getDefaultSSLParameters().getCipherSuites());
293    }
294
295    private static void filterCiphers(Iterable<String> ciphers) {
296        // Filter all non-TLS ciphers.
297        Iterator<String> iter = ciphers.iterator();
298        while (iter.hasNext()) {
299            String cipher = iter.next();
300            if (cipher.startsWith("SSL_") || cipher.startsWith("TLS_EMPTY")
301                    || cipher.contains("_RC4_")) {
302                iter.remove();
303            }
304        }
305    }
306
307    /**
308     * Picks a port that is not used right at this moment.
309     * Warning: Not thread safe. May see "BindException: Address already in use: bind" if using the
310     * returned port to create a new server socket when other threads/processes are concurrently
311     * creating new sockets without a specific port.
312     */
313    public static int pickUnusedPort() {
314        try {
315            ServerSocket serverSocket = new ServerSocket(0);
316            int port = serverSocket.getLocalPort();
317            serverSocket.close();
318            return port;
319        } catch (IOException e) {
320            throw new RuntimeException(e);
321        }
322    }
323
324    /**
325     * Creates a text message of the given length.
326     */
327    public static byte[] newTextMessage(int length) {
328        byte[] msg = new byte[length];
329        for (int msgIndex = 0; msgIndex < length;) {
330            int remaining = length - msgIndex;
331            int numChars = Math.min(remaining, CHARS.length);
332            System.arraycopy(CHARS, 0, msg, msgIndex, numChars);
333            msgIndex += numChars;
334        }
335        return msg;
336    }
337
338    static SSLContext newClientSslContext(Provider provider) {
339        SSLContext context = newContext(provider);
340        return initClientSslContext(context);
341    }
342
343    static SSLContext newServerSslContext(Provider provider) {
344        SSLContext context = newContext(provider);
345        return initServerSslContext(context);
346    }
347
348    /**
349     * Initializes the given client-side {@code context} with a default cert.
350     */
351    public static SSLContext initClientSslContext(SSLContext context) {
352        return initSslContext(context, TestKeyStore.getClient());
353    }
354
355    /**
356     * Initializes the given server-side {@code context} with the given cert chain and private key.
357     */
358    public static SSLContext initServerSslContext(SSLContext context) {
359        return initSslContext(context, TestKeyStore.getServer());
360    }
361
362    /**
363     * Initializes the given {@code context} from the {@code keyStore}.
364     */
365    static SSLContext initSslContext(SSLContext context, TestKeyStore keyStore) {
366        try {
367            context.init(keyStore.keyManagers, keyStore.trustManagers, null);
368            return context;
369        } catch (Exception e) {
370            throw new RuntimeException(e);
371        }
372    }
373
374    /**
375     * Performs the intial TLS handshake between the two {@link SSLEngine} instances.
376     */
377    public static void doEngineHandshake(SSLEngine clientEngine, SSLEngine serverEngine,
378        ByteBuffer clientAppBuffer, ByteBuffer clientPacketBuffer, ByteBuffer serverAppBuffer,
379        ByteBuffer serverPacketBuffer, boolean beginHandshake) throws SSLException {
380        if (beginHandshake) {
381            clientEngine.beginHandshake();
382            serverEngine.beginHandshake();
383        }
384
385        SSLEngineResult clientResult;
386        SSLEngineResult serverResult;
387
388        boolean clientHandshakeFinished = false;
389        boolean serverHandshakeFinished = false;
390
391        do {
392            int cTOsPos = clientPacketBuffer.position();
393            int sTOcPos = serverPacketBuffer.position();
394
395            clientResult = clientEngine.wrap(EMPTY_BUFFER, clientPacketBuffer);
396            runDelegatedTasks(clientResult, clientEngine);
397            serverResult = serverEngine.wrap(EMPTY_BUFFER, serverPacketBuffer);
398            runDelegatedTasks(serverResult, serverEngine);
399
400            // Verify that the consumed and produced number match what is in the buffers now.
401            assertEquals(0, clientResult.bytesConsumed());
402            assertEquals(0, serverResult.bytesConsumed());
403            assertEquals(clientPacketBuffer.position() - cTOsPos, clientResult.bytesProduced());
404            assertEquals(serverPacketBuffer.position() - sTOcPos, serverResult.bytesProduced());
405
406            clientPacketBuffer.flip();
407            serverPacketBuffer.flip();
408
409            // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED
410            if (isHandshakeFinished(clientResult)) {
411                assertFalse(clientHandshakeFinished);
412                clientHandshakeFinished = true;
413            }
414            if (isHandshakeFinished(serverResult)) {
415                assertFalse(serverHandshakeFinished);
416                serverHandshakeFinished = true;
417            }
418
419            cTOsPos = clientPacketBuffer.position();
420            sTOcPos = serverPacketBuffer.position();
421
422            int clientAppReadBufferPos = clientAppBuffer.position();
423            int serverAppReadBufferPos = serverAppBuffer.position();
424
425            clientResult = clientEngine.unwrap(serverPacketBuffer, clientAppBuffer);
426            runDelegatedTasks(clientResult, clientEngine);
427            serverResult = serverEngine.unwrap(clientPacketBuffer, serverAppBuffer);
428            runDelegatedTasks(serverResult, serverEngine);
429
430            // Verify that the consumed and produced number match what is in the buffers now.
431            assertEquals(serverPacketBuffer.position() - sTOcPos, clientResult.bytesConsumed());
432            assertEquals(clientPacketBuffer.position() - cTOsPos, serverResult.bytesConsumed());
433            assertEquals(clientAppBuffer.position() - clientAppReadBufferPos,
434                clientResult.bytesProduced());
435            assertEquals(serverAppBuffer.position() - serverAppReadBufferPos,
436                serverResult.bytesProduced());
437
438            clientPacketBuffer.compact();
439            serverPacketBuffer.compact();
440
441            // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED
442            if (isHandshakeFinished(clientResult)) {
443                assertFalse(clientHandshakeFinished);
444                clientHandshakeFinished = true;
445            }
446            if (isHandshakeFinished(serverResult)) {
447                assertFalse(serverHandshakeFinished);
448                serverHandshakeFinished = true;
449            }
450        } while (!clientHandshakeFinished || !serverHandshakeFinished);
451    }
452
453    private static boolean isHandshakeFinished(SSLEngineResult result) {
454        return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED;
455    }
456
457    private static void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) {
458        if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
459            for (;;) {
460                Runnable task = engine.getDelegatedTask();
461                if (task == null) {
462                    break;
463                }
464                task.run();
465            }
466        }
467    }
468
469    /**
470     * Decodes the provided hexadecimal string into a byte array.  Odd-length inputs
471     * are not allowed.
472     *
473     * Throws an {@code IllegalArgumentException} if the input is malformed.
474     */
475    public static byte[] decodeHex(String encoded) throws IllegalArgumentException {
476        return decodeHex(encoded.toCharArray());
477    }
478
479    /**
480     * Decodes the provided hexadecimal string into a byte array. If {@code allowSingleChar}
481     * is {@code true} odd-length inputs are allowed and the first character is interpreted
482     * as the lower bits of the first result byte.
483     *
484     * Throws an {@code IllegalArgumentException} if the input is malformed.
485     */
486    public static byte[] decodeHex(String encoded, boolean allowSingleChar) throws IllegalArgumentException {
487        return decodeHex(encoded.toCharArray(), allowSingleChar);
488    }
489
490    /**
491     * Decodes the provided hexadecimal string into a byte array.  Odd-length inputs
492     * are not allowed.
493     *
494     * Throws an {@code IllegalArgumentException} if the input is malformed.
495     */
496    public static byte[] decodeHex(char[] encoded) throws IllegalArgumentException {
497        return decodeHex(encoded, false);
498    }
499
500    /**
501     * Decodes the provided hexadecimal string into a byte array. If {@code allowSingleChar}
502     * is {@code true} odd-length inputs are allowed and the first character is interpreted
503     * as the lower bits of the first result byte.
504     *
505     * Throws an {@code IllegalArgumentException} if the input is malformed.
506     */
507    public static byte[] decodeHex(char[] encoded, boolean allowSingleChar) throws IllegalArgumentException {
508        int resultLengthBytes = (encoded.length + 1) / 2;
509        byte[] result = new byte[resultLengthBytes];
510
511        int resultOffset = 0;
512        int i = 0;
513        if (allowSingleChar) {
514            if ((encoded.length % 2) != 0) {
515                // Odd number of digits -- the first digit is the lower 4 bits of the first result byte.
516                result[resultOffset++] = (byte) toDigit(encoded, i);
517                i++;
518            }
519        } else {
520            if ((encoded.length % 2) != 0) {
521                throw new IllegalArgumentException("Invalid input length: " + encoded.length);
522            }
523        }
524
525        for (int len = encoded.length; i < len; i += 2) {
526            result[resultOffset++] = (byte) ((toDigit(encoded, i) << 4) | toDigit(encoded, i + 1));
527        }
528
529        return result;
530    }
531
532
533    private static int toDigit(char[] str, int offset) throws IllegalArgumentException {
534        // NOTE: that this isn't really a code point in the traditional sense, since we're
535        // just rejecting surrogate pairs outright.
536        int pseudoCodePoint = str[offset];
537
538        if ('0' <= pseudoCodePoint && pseudoCodePoint <= '9') {
539            return pseudoCodePoint - '0';
540        } else if ('a' <= pseudoCodePoint && pseudoCodePoint <= 'f') {
541            return 10 + (pseudoCodePoint - 'a');
542        } else if ('A' <= pseudoCodePoint && pseudoCodePoint <= 'F') {
543            return 10 + (pseudoCodePoint - 'A');
544        }
545
546        throw new IllegalArgumentException("Illegal char: " + str[offset] +
547                " at offset " + offset);
548    }
549}
550