1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License
15 */
16
17package org.conscrypt;
18
19import static org.conscrypt.testing.TestUtil.PROTOCOL_TLS_V1_2;
20import static org.conscrypt.testing.TestUtil.initEngine;
21import static org.conscrypt.testing.TestUtil.initSslContext;
22import static org.conscrypt.testing.TestUtil.newTextMessage;
23import static org.junit.Assert.assertArrayEquals;
24import static org.junit.Assert.assertEquals;
25
26import java.io.ByteArrayOutputStream;
27import java.nio.ByteBuffer;
28import java.security.NoSuchAlgorithmException;
29import java.util.ArrayList;
30import java.util.Arrays;
31import java.util.List;
32import javax.net.ssl.SSLContext;
33import javax.net.ssl.SSLEngine;
34import javax.net.ssl.SSLEngineResult;
35import javax.net.ssl.SSLEngineResult.HandshakeStatus;
36import javax.net.ssl.SSLEngineResult.Status;
37import javax.net.ssl.SSLException;
38import javax.net.ssl.SSLHandshakeException;
39import libcore.java.security.TestKeyStore;
40import org.conscrypt.testing.TestUtil;
41import org.junit.Test;
42import org.junit.runner.RunWith;
43import org.junit.runners.Parameterized;
44import org.junit.runners.Parameterized.Parameter;
45import org.junit.runners.Parameterized.Parameters;
46
47@RunWith(Parameterized.class)
48public class OpenSSLEngineImplTest {
49    private static final String CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256";
50    private static final int MESSAGE_SIZE = 4096;
51
52    public enum BufferType {
53        HEAP {
54            @Override
55            ByteBuffer newBuffer(int size) {
56                return ByteBuffer.allocate(size);
57            }
58        },
59        DIRECT {
60            @Override
61            ByteBuffer newBuffer(int size) {
62                return ByteBuffer.allocateDirect(size);
63            }
64        };
65
66        abstract ByteBuffer newBuffer(int size);
67    }
68
69    private enum ClientAuth {
70        NONE {
71            @Override
72            SSLEngine apply(SSLEngine engine) {
73                engine.setWantClientAuth(false);
74                engine.setNeedClientAuth(false);
75                return engine;
76            }
77        },
78        OPTIONAL {
79            @Override
80            SSLEngine apply(SSLEngine engine) {
81                engine.setWantClientAuth(true);
82                engine.setNeedClientAuth(false);
83                return engine;
84            }
85        },
86        REQUIRED {
87            @Override
88            SSLEngine apply(SSLEngine engine) {
89                engine.setWantClientAuth(false);
90                engine.setNeedClientAuth(true);
91                return engine;
92            }
93        };
94
95        abstract SSLEngine apply(SSLEngine engine);
96    }
97
98    @Parameters(name = "{0}")
99    public static Iterable<BufferType> data() {
100        return Arrays.asList(BufferType.HEAP, BufferType.DIRECT);
101    }
102
103    @Parameter
104    public BufferType bufferType;
105
106    private SSLEngine clientEngine;
107    private SSLEngine serverEngine;
108
109    @Test
110    public void mutualAuthWithSameCertsShouldSucceed() throws Exception {
111        doMutualAuthHandshake(TestKeyStore.getServer(), TestKeyStore.getServer(), ClientAuth.NONE);
112    }
113
114    @Test
115    public void mutualAuthWithDifferentCertsShouldSucceed() throws Exception {
116        doMutualAuthHandshake(TestKeyStore.getClient(), TestKeyStore.getServer(), ClientAuth.NONE);
117    }
118
119    @Test(expected = SSLHandshakeException.class)
120    public void mutualAuthWithUntrustedServerShouldFail() throws Exception {
121        doMutualAuthHandshake(TestKeyStore.getClientCA2(), TestKeyStore.getServer(), ClientAuth.NONE);
122    }
123
124    @Test(expected = SSLHandshakeException.class)
125    public void mutualAuthWithUntrustedClientShouldFail() throws Exception {
126        doMutualAuthHandshake(TestKeyStore.getClient(), TestKeyStore.getClient(), ClientAuth.NONE);
127    }
128
129    @Test
130    public void optionalClientAuthShouldSucceed() throws Exception {
131        doMutualAuthHandshake(TestKeyStore.getClient(), TestKeyStore.getServer(), ClientAuth.OPTIONAL);
132    }
133
134    @Test(expected = SSLHandshakeException.class)
135    public void optionalClientAuthShouldFail() throws Exception {
136        doMutualAuthHandshake(TestKeyStore.getClient(), TestKeyStore.getClient(), ClientAuth.OPTIONAL);
137    }
138
139    @Test
140    public void requiredClientAuthShouldSucceed() throws Exception {
141        doMutualAuthHandshake(TestKeyStore.getServer(), TestKeyStore.getServer(), ClientAuth.REQUIRED);
142    }
143
144    @Test(expected = SSLHandshakeException.class)
145    public void requiredClientAuthShouldFail() throws Exception {
146        doMutualAuthHandshake(TestKeyStore.getClient(), TestKeyStore.getClient(), ClientAuth.REQUIRED);
147    }
148
149    @Test
150    public void exchangeMessages() throws Exception {
151        setupEngines(TestKeyStore.getClient(), TestKeyStore.getServer());
152        TestUtil.doEngineHandshake(clientEngine, serverEngine);
153
154        ByteBuffer clientCleartextBuffer = bufferType.newBuffer(MESSAGE_SIZE);
155        clientCleartextBuffer.put(newTextMessage(MESSAGE_SIZE));
156        clientCleartextBuffer.flip();
157
158        // Wrap the original message and create the encrypted data.
159        final int numMessages = 100;
160        ByteBuffer[] encryptedBuffers = new ByteBuffer[numMessages];
161        for (int i = 0; i < numMessages; ++i) {
162            clientCleartextBuffer.position(0);
163            ByteBuffer out = bufferType.newBuffer(clientEngine.getSession().getPacketBufferSize());
164            SSLEngineResult wrapResult = clientEngine.wrap(clientCleartextBuffer, out);
165            assertEquals(SSLEngineResult.Status.OK, wrapResult.getStatus());
166            out.flip();
167            encryptedBuffers[i] = out;
168        }
169
170        // Create the expected cleartext message
171        clientCleartextBuffer.position(0);
172        byte[] expectedMessage = toArray(clientCleartextBuffer);
173
174        // Unwrap the all of the encrypted messages.
175        ByteBuffer[] cleartextBuffers = new ByteBuffer[numMessages];
176        for (int i = 0; i < numMessages; ++i) {
177            ByteBuffer out = bufferType.newBuffer(2 * MESSAGE_SIZE);
178            cleartextBuffers[i] = out;
179            SSLEngineResult unwrapResult = Conscrypt.Engines.unwrap(serverEngine, encryptedBuffers,
180                    new ByteBuffer[] {out});
181            assertEquals(SSLEngineResult.Status.OK, unwrapResult.getStatus());
182            assertEquals(MESSAGE_SIZE, unwrapResult.bytesProduced());
183
184            out.flip();
185            byte[] actualMessage = toArray(out);
186            assertArrayEquals(expectedMessage, actualMessage);
187        }
188    }
189
190    @Test
191    public void exchangeLargeMessage() throws Exception {
192        setupEngines(TestKeyStore.getClient(), TestKeyStore.getServer());
193        TestUtil.doEngineHandshake(clientEngine, serverEngine);
194
195        // Create the input message.
196        final int largeMessageSize = 16413;
197        final byte[] message = newTextMessage(largeMessageSize);
198        ByteBuffer inputBuffer = bufferType.newBuffer(largeMessageSize);
199        inputBuffer.put(message);
200        inputBuffer.flip();
201
202        // Encrypt the input message.
203        List<ByteBuffer> encryptedBufferList = new ArrayList<ByteBuffer>();
204        while(inputBuffer.hasRemaining()) {
205            ByteBuffer encryptedBuffer = bufferType.newBuffer(clientEngine.getSession().getPacketBufferSize());
206            SSLEngineResult wrapResult = clientEngine.wrap(inputBuffer, encryptedBuffer);
207            assertEquals(SSLEngineResult.Status.OK, wrapResult.getStatus());
208            encryptedBuffer.flip();
209            encryptedBufferList.add(encryptedBuffer);
210        }
211
212        // Unwrap the all of the encrypted messages.
213        ByteArrayOutputStream cleartextStream = new ByteArrayOutputStream();
214        ByteBuffer[] encryptedBuffers = encryptedBufferList.toArray(new ByteBuffer[encryptedBufferList.size()]);
215        int decryptedBufferSize = 8192;
216        final ByteBuffer decryptedBuffer = bufferType.newBuffer(decryptedBufferSize);
217        for (ByteBuffer encryptedBuffer : encryptedBuffers) {
218            SSLEngineResult.Status status = SSLEngineResult.Status.OK;
219            while (encryptedBuffer.hasRemaining() || status.equals(Status.BUFFER_OVERFLOW)) {
220                if (!decryptedBuffer.hasRemaining()) {
221                    decryptedBuffer.clear();
222                }
223                int prevPos = decryptedBuffer.position();
224                SSLEngineResult unwrapResult = Conscrypt.Engines.unwrap(serverEngine,
225                        encryptedBuffers, new ByteBuffer[]{decryptedBuffer});
226                status = unwrapResult.getStatus();
227                int newPos = decryptedBuffer.position();
228                int bytesProduced = unwrapResult.bytesProduced();
229                assertEquals(bytesProduced, newPos - prevPos);
230
231                // Add any generated bytes to the output stream.
232                if (bytesProduced > 0) {
233                    byte[] decryptedBytes = new byte[unwrapResult.bytesProduced()];
234
235                    // Read the chunk that was just written to the output array.
236                    int limit = decryptedBuffer.limit();
237                    decryptedBuffer.limit(newPos);
238                    decryptedBuffer.position(prevPos);
239                    decryptedBuffer.get(decryptedBytes);
240
241                    // Restore the position and limit.
242                    decryptedBuffer.limit(limit);
243
244                    // Write the decrypted bytes to the stream.
245                    cleartextStream.write(decryptedBytes);
246                }
247            }
248        }
249        byte[] actualMessage = cleartextStream.toByteArray();
250        assertArrayEquals(message, actualMessage);
251    }
252
253    private void doMutualAuthHandshake(TestKeyStore clientKs, TestKeyStore serverKs, ClientAuth clientAuth) throws Exception {
254        setupEngines(clientKs, serverKs);
255        clientAuth.apply(serverEngine);
256        TestUtil.doEngineHandshake(clientEngine, serverEngine);
257        assertEquals(HandshakeStatus.NOT_HANDSHAKING, clientEngine.getHandshakeStatus());
258        assertEquals(HandshakeStatus.NOT_HANDSHAKING, serverEngine.getHandshakeStatus());
259    }
260
261    private void setupEngines(TestKeyStore clientKeyStore, TestKeyStore serverKeyStore) throws SSLException {
262        SSLContext clientContext = initSslContext(newContext(), clientKeyStore);
263        SSLContext serverContext = initSslContext(newContext(), serverKeyStore);
264
265        clientEngine = initEngine(clientContext.createSSLEngine(), CIPHER, true);
266        serverEngine = initEngine(serverContext.createSSLEngine(), CIPHER, false);
267    }
268
269    private static byte[] toArray(ByteBuffer buffer) {
270        byte[] data = new byte[buffer.remaining()];
271        buffer.get(data);
272        return data;
273    }
274
275    private static SSLContext newContext() {
276        try {
277            return SSLContext.getInstance(PROTOCOL_TLS_V1_2, new OpenSSLProvider());
278        } catch (NoSuchAlgorithmException e) {
279            throw new RuntimeException(e);
280        }
281    }
282}
283