1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License
15 */
16
17package org.conscrypt;
18
19import static org.junit.Assert.assertArrayEquals;
20import static org.junit.Assert.assertEquals;
21import static org.junit.Assert.assertNotEquals;
22
23import java.io.EOFException;
24import java.io.IOException;
25import java.net.InetSocketAddress;
26import java.nio.ByteBuffer;
27import java.nio.channels.ServerSocketChannel;
28import java.nio.channels.SocketChannel;
29import java.util.Arrays;
30import java.util.LinkedHashSet;
31import java.util.Set;
32import java.util.concurrent.ExecutionException;
33import java.util.concurrent.ExecutorService;
34import java.util.concurrent.Executors;
35import java.util.concurrent.Future;
36import java.util.concurrent.TimeUnit;
37import java.util.concurrent.TimeoutException;
38import javax.net.ssl.SSLContext;
39import javax.net.ssl.SSLEngine;
40import javax.net.ssl.SSLEngineResult;
41import javax.net.ssl.SSLEngineResult.Status;
42import javax.net.ssl.SSLSocket;
43import javax.net.ssl.SSLSocketFactory;
44import org.conscrypt.java.security.TestKeyStore;
45import org.junit.After;
46import org.junit.Before;
47import org.junit.Test;
48import org.junit.runner.RunWith;
49import org.junit.runners.Parameterized;
50import org.junit.runners.Parameterized.Parameter;
51import org.junit.runners.Parameterized.Parameters;
52
53/**
54 * This tests that server-initiated cipher renegotiation works properly with a Conscrypt client.
55 * BoringSSL does not support user-initiated renegotiation, so we use the JDK implementation for
56 * the server.
57 */
58@RunWith(Parameterized.class)
59public class RenegotiationTest {
60    private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0);
61    private static final String[] CIPHERS = TestUtils.getCommonCipherSuites();
62    private static final byte[] MESSAGE_BYTES = "Hello".getBytes(TestUtils.UTF_8);
63    private static final ByteBuffer MESSAGE_BUFFER =
64            ByteBuffer.wrap(MESSAGE_BYTES).asReadOnlyBuffer();
65    private static final int MESSAGE_LENGTH = MESSAGE_BYTES.length;
66
67    public enum SocketType {
68        FILE_DESCRIPTOR {
69            @Override
70            Client newClient(int port) {
71                return new Client(false, port);
72            }
73        },
74        ENGINE {
75            @Override
76            Client newClient(int port) {
77                return new Client(true, port);
78            }
79        };
80
81        abstract Client newClient(int port);
82    }
83
84    @Parameters(name = "{0}")
85    public static Object[] data() {
86        return new Object[] {SocketType.FILE_DESCRIPTOR, SocketType.ENGINE};
87    }
88
89    @Parameter
90    public SocketType socketType;
91
92    private Client client;
93    private Server server;
94
95    @Before
96    public void setup() throws Exception {
97        server = new Server();
98        Future<?> connectedFuture = server.start();
99
100        client = socketType.newClient(server.port());
101        client.start();
102
103        // Wait for the initial connection to complete.
104        connectedFuture.get(5, TimeUnit.SECONDS);
105    }
106
107    @After
108    public void teardown() {
109        client.stop();
110        server.stop();
111    }
112
113    @Test
114    public void test() throws Exception {
115        client.socket.startHandshake();
116        String initialCipher = client.socket.getSession().getCipherSuite();
117
118        client.sendMessage();
119
120        Future<?> repliesFuture = client.readReplies();
121        server.await(5, TimeUnit.SECONDS);
122        repliesFuture.get(5, TimeUnit.SECONDS);
123
124        // Verify that the cipher has changed.
125        assertNotEquals(initialCipher, client.socket.getSession().getCipherSuite());
126    }
127
128    private static SSLContext newConscryptClientContext() {
129        SSLContext context = TestUtils.newContext(TestUtils.getConscryptProvider());
130        return TestUtils.initSslContext(context, TestKeyStore.getClient());
131    }
132
133    private static SSLContext newJdkServerContext() {
134        SSLContext context = TestUtils.newContext(TestUtils.getJdkProvider());
135        return TestUtils.initSslContext(context, TestKeyStore.getServer());
136    }
137
138    private static final class Client {
139        private final SSLSocket socket;
140        private ExecutorService executor;
141        private volatile boolean stopping;
142
143        Client(boolean useEngineSocket, int port) {
144            try {
145                SSLSocketFactory socketFactory = newConscryptClientContext().getSocketFactory();
146                Conscrypt.setUseEngineSocket(socketFactory, useEngineSocket);
147                socket = (SSLSocket) socketFactory.createSocket(
148                        TestUtils.getLoopbackAddress(), port);
149                socket.setEnabledCipherSuites(CIPHERS);
150            } catch (IOException e) {
151                throw new RuntimeException(e);
152            }
153        }
154
155        void start() {
156            try {
157                executor = Executors.newSingleThreadExecutor();
158                socket.startHandshake();
159            } catch (IOException e) {
160                e.printStackTrace();
161                throw new RuntimeException(e);
162            }
163        }
164
165        void stop() {
166            try {
167                stopping = true;
168                socket.close();
169
170                if (executor != null) {
171                    executor.shutdown();
172                    executor.awaitTermination(5, TimeUnit.SECONDS);
173                    executor = null;
174                }
175            } catch (RuntimeException e) {
176                throw e;
177            } catch (Exception e) {
178                throw new RuntimeException(e);
179            }
180        }
181
182        Future<?> readReplies() {
183            return executor.submit(new Runnable() {
184                @Override
185                public void run() {
186                    readReply();
187                }
188            });
189        }
190
191        private void readReply() {
192            try {
193                byte[] buffer = new byte[MESSAGE_LENGTH];
194                int totalBytesRead = 0;
195                while (totalBytesRead < MESSAGE_LENGTH) {
196                    int remaining = MESSAGE_LENGTH - totalBytesRead;
197                    int bytesRead = socket.getInputStream().read(buffer, totalBytesRead, remaining);
198                    if (bytesRead == -1) {
199                        throw new EOFException();
200                    }
201                    totalBytesRead += bytesRead;
202                }
203
204                // Verify the reply is correct.
205                assertEquals(MESSAGE_LENGTH, totalBytesRead);
206                assertArrayEquals(MESSAGE_BYTES, buffer);
207            } catch (IOException e) {
208                throw new RuntimeException(e);
209            }
210        }
211
212        void sendMessage() throws IOException {
213            try {
214                socket.getOutputStream().write(MESSAGE_BYTES);
215                socket.getOutputStream().flush();
216            } catch (IOException e) {
217                throw new RuntimeException(e);
218            }
219        }
220    }
221
222    private static final class Server {
223        private final ServerSocketChannel serverChannel;
224        private final SSLEngine engine;
225        private final ByteBuffer inboundPacketBuffer;
226        private final ByteBuffer inboundAppBuffer;
227        private final ByteBuffer outboundPacketBuffer;
228        private final Set<String> ciphers = new LinkedHashSet<String>(Arrays.asList(CIPHERS));
229        private SocketChannel channel;
230        private ExecutorService executor;
231        private volatile boolean stopping;
232        private volatile Future<?> echoFuture;
233
234        Server() throws IOException {
235            serverChannel = ServerSocketChannel.open();
236            serverChannel.socket().bind(new InetSocketAddress(TestUtils.getLoopbackAddress(), 0));
237            engine = newJdkServerContext().createSSLEngine();
238            engine.setEnabledCipherSuites(CIPHERS);
239            engine.setUseClientMode(false);
240
241            inboundPacketBuffer =
242                    ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize());
243            inboundAppBuffer =
244                    ByteBuffer.allocateDirect(engine.getSession().getApplicationBufferSize());
245            outboundPacketBuffer =
246                    ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize());
247        }
248
249        Future<?> start() throws IOException {
250            executor = Executors.newSingleThreadExecutor();
251            return executor.submit(new AcceptTask());
252        }
253
254        void await(long timeout, TimeUnit unit)
255                throws InterruptedException, ExecutionException, TimeoutException {
256            echoFuture.get(timeout, unit);
257        }
258
259        void stop() {
260            try {
261                stopping = true;
262
263                if (channel != null) {
264                    channel.close();
265                    channel = null;
266                }
267
268                serverChannel.close();
269
270                if (executor != null) {
271                    executor.shutdown();
272                    executor.awaitTermination(5, TimeUnit.SECONDS);
273                    executor = null;
274                }
275            } catch (IOException e) {
276                throw new RuntimeException(e);
277            } catch (InterruptedException e) {
278                throw new RuntimeException(e);
279            }
280        }
281
282        int port() {
283            return serverChannel.socket().getLocalPort();
284        }
285
286        private final class AcceptTask implements Runnable {
287            @Override
288            public void run() {
289                try {
290                    if (stopping) {
291                        return;
292                    }
293                    channel = serverChannel.accept();
294                    channel.configureBlocking(false);
295
296                    doHandshake();
297
298                    if (stopping) {
299                        return;
300                    }
301                    echoFuture = executor.submit(new EchoTask());
302                } catch (Throwable e) {
303                    e.printStackTrace();
304                    throw new RuntimeException(e);
305                }
306            }
307        }
308
309        private final class EchoTask implements Runnable {
310            @Override
311            public void run() {
312                try {
313                    readMessage();
314                    renegotiate();
315                    reply();
316                } catch (Throwable e) {
317                    e.printStackTrace();
318                    throw new RuntimeException(e);
319                }
320            }
321
322            private void renegotiate() throws Exception {
323                // Remove the current cipher from the set and renegotiate to force a new
324                // cipher to be selected.
325                String currentCipher = engine.getSession().getCipherSuite();
326                ciphers.remove(currentCipher);
327                engine.setEnabledCipherSuites(ciphers.toArray(new String[ciphers.size()]));
328                doHandshake();
329            }
330
331            private void reply() throws IOException {
332                SSLEngineResult result = wrap(newMessage());
333                if (result.getStatus() != Status.OK) {
334                    throw new RuntimeException("Wrap failed. Status: " + result.getStatus());
335                }
336            }
337
338            private ByteBuffer newMessage() {
339                return MESSAGE_BUFFER.duplicate();
340            }
341
342            private void readMessage() throws IOException {
343                int totalProduced = 0;
344                while (!stopping) {
345                    SSLEngineResult result = unwrap();
346                    if (result.getStatus() != Status.OK) {
347                        throw new RuntimeException("Failed reading message: " + result);
348                    }
349                    totalProduced += result.bytesProduced();
350                    if (totalProduced == MESSAGE_LENGTH) {
351                        return;
352                    }
353                }
354            }
355        }
356
357        private SSLEngineResult wrap(ByteBuffer src) throws IOException {
358            outboundPacketBuffer.clear();
359
360            // Check if the engine has bytes to wrap.
361            SSLEngineResult result = engine.wrap(src, outboundPacketBuffer);
362
363            // Write any wrapped bytes to the socket.
364            outboundPacketBuffer.flip();
365
366            do {
367                channel.write(outboundPacketBuffer);
368            } while (outboundPacketBuffer.hasRemaining());
369
370            return result;
371        }
372
373        private SSLEngineResult unwrap() throws IOException {
374            // Unwrap any available bytes from the socket.
375            SSLEngineResult result = null;
376            boolean done = false;
377            while (!done) {
378                if (channel.read(inboundPacketBuffer) == -1) {
379                    throw new EOFException();
380                }
381                // Just clear the app buffer - we don't really use it.
382                inboundAppBuffer.clear();
383                inboundPacketBuffer.flip();
384                result = engine.unwrap(inboundPacketBuffer, inboundAppBuffer);
385                switch (result.getStatus()) {
386                    case BUFFER_UNDERFLOW:
387                        // Continue reading from the socket in a moment.
388                        try {
389                            Thread.sleep(10);
390                        } catch (InterruptedException e) {
391                            throw new RuntimeException(e);
392                        }
393                        break;
394                    case OK:
395                        done = true;
396                        break;
397                    default: { throw new RuntimeException("Unexpected unwrap result: " + result); }
398                }
399
400                // Compact for the next socket read.
401                inboundPacketBuffer.compact();
402            }
403            return result;
404        }
405
406        private void doHandshake() throws IOException {
407            engine.beginHandshake();
408
409            boolean done = false;
410            while (!done) {
411                switch (engine.getHandshakeStatus()) {
412                    case NEED_WRAP: {
413                        wrap(EMPTY_BUFFER);
414                        break;
415                    }
416                    case NEED_UNWRAP: {
417                        unwrap();
418                        break;
419                    }
420                    case NEED_TASK: {
421                        runDelegatedTasks();
422                        break;
423                    }
424                    default: {
425                        done = true;
426                        break;
427                    }
428                }
429            }
430        }
431
432        private void runDelegatedTasks() {
433            for (;;) {
434                Runnable task = engine.getDelegatedTask();
435                if (task == null) {
436                    break;
437                }
438                task.run();
439            }
440        }
441    }
442}
443