1/*
2 * Copyright (C) 2010 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package libcore.javax.net.ssl;
18
19import java.io.IOException;
20import java.nio.ByteBuffer;
21import javax.net.ssl.SSLEngine;
22import javax.net.ssl.SSLEngineResult;
23import javax.net.ssl.SSLEngineResult.HandshakeStatus;
24import javax.net.ssl.SSLException;
25import javax.net.ssl.SSLSession;
26import junit.framework.Assert;
27
28/**
29 * TestSSLEnginePair is a convenience class for other tests that want
30 * a pair of connected and handshaked client and server SSLEngines for
31 * testing.
32 */
33public final class TestSSLEnginePair extends Assert implements AutoCloseable {
34    public final TestSSLContext c;
35    public final SSLEngine server;
36    public final SSLEngine client;
37
38    private TestSSLEnginePair(TestSSLContext c,
39                              SSLEngine server,
40                              SSLEngine client) {
41        this.c = c;
42        this.server = server;
43        this.client = client;
44    }
45
46    public static TestSSLEnginePair create() throws IOException {
47        return create(null);
48    }
49
50    public static TestSSLEnginePair create(Hooks hooks) throws IOException {
51        return create(TestSSLContext.create(), hooks);
52    }
53
54    public static TestSSLEnginePair create(TestSSLContext c, Hooks hooks) throws IOException {
55        return create(c, hooks, null);
56    }
57
58    public static TestSSLEnginePair create(TestSSLContext c, Hooks hooks, boolean[] finished)
59            throws IOException {
60        SSLEngine[] engines = connect(c, hooks, finished);
61        return new TestSSLEnginePair(c, engines[0], engines[1]);
62    }
63
64    public static SSLEngine[] connect(TestSSLContext c, Hooks hooks) throws IOException {
65        return connect(c, hooks, null);
66    }
67
68    /**
69     * Create a new connected server/client engine pair within a
70     * existing SSLContext. Optionally specify clientCipherSuites to
71     * allow forcing new SSLSession to test SSLSessionContext
72     * caching. Optionally specify serverCipherSuites for testing
73     * cipher suite negotiation.
74     */
75    public static SSLEngine[] connect(final TestSSLContext c,
76                                      Hooks hooks,
77                                      boolean finished[]) throws IOException {
78        if (hooks == null) {
79            hooks = new Hooks();
80        }
81
82        // FINISHED state should be returned only once.
83        boolean[] clientFinished = new boolean[1];
84        boolean[] serverFinished = new boolean[1];
85
86        SSLSession session = c.clientContext.createSSLEngine().getSession();
87
88        int packetBufferSize = session.getPacketBufferSize();
89        ByteBuffer clientToServer = ByteBuffer.allocate(packetBufferSize);
90        ByteBuffer serverToClient = ByteBuffer.allocate(packetBufferSize);
91
92        int applicationBufferSize = session.getApplicationBufferSize();
93        ByteBuffer scratch = ByteBuffer.allocate(applicationBufferSize);
94
95        SSLEngine client = c.clientContext.createSSLEngine(c.host.getHostName(), c.port);
96        SSLEngine server = c.serverContext.createSSLEngine();
97        client.setUseClientMode(true);
98        server.setUseClientMode(false);
99        hooks.beforeBeginHandshake(client, server);
100        client.beginHandshake();
101        server.beginHandshake();
102
103        while (true) {
104            boolean clientDone = client.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING;
105            boolean serverDone = server.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING;
106            if (clientDone && serverDone) {
107                break;
108            }
109
110            boolean progress = false;
111            if (!clientDone) {
112                progress |= handshakeCompleted(client,
113                                               clientToServer,
114                                               serverToClient,
115                                               scratch,
116                                               clientFinished);
117            }
118            if (!serverDone) {
119                progress |= handshakeCompleted(server,
120                                               serverToClient,
121                                               clientToServer,
122                                               scratch,
123                                               serverFinished);
124            }
125            if (!progress) {
126                break;
127            }
128        }
129
130        if (finished != null) {
131            assertEquals(2, finished.length);
132            finished[0] = clientFinished[0];
133            finished[1] = serverFinished[0];
134        }
135        return new SSLEngine[] { server, client };
136    }
137
138    public static class Hooks {
139        void beforeBeginHandshake(SSLEngine client, SSLEngine server) {}
140    }
141
142    public void close() throws SSLException {
143        close(new SSLEngine[] { client, server });
144    }
145
146    public static void close(SSLEngine[] engines) {
147        try {
148            for (SSLEngine engine : engines) {
149                if (engine != null) {
150                    engine.closeInbound();
151                    engine.closeOutbound();
152                }
153            }
154        } catch (Exception e) {
155            throw new RuntimeException(e);
156        }
157    }
158
159    private static boolean handshakeCompleted(SSLEngine engine,
160                                              ByteBuffer output,
161                                              ByteBuffer input,
162                                              ByteBuffer scratch,
163                                              boolean[] finished) throws IOException {
164        try {
165            // make the other side's output into our input
166            input.flip();
167
168            HandshakeStatus status = engine.getHandshakeStatus();
169            switch (status) {
170
171                case NEED_TASK: {
172                    boolean progress = false;
173                    while (true) {
174                        Runnable runnable = engine.getDelegatedTask();
175                        if (runnable == null) {
176                            return progress;
177                        }
178                        runnable.run();
179                        progress = true;
180                    }
181                }
182
183                case NEED_UNWRAP: {
184                    // avoid underflow
185                    if (input.remaining() == 0) {
186                        return false;
187                    }
188                    int inputPositionBefore = input.position();
189                    SSLEngineResult unwrapResult = engine.unwrap(input, scratch);
190                    assertEquals(SSLEngineResult.Status.OK, unwrapResult.getStatus());
191                    assertEquals(0, scratch.position());
192                    assertEquals(0, unwrapResult.bytesProduced());
193                    assertEquals(input.position() - inputPositionBefore, unwrapResult.bytesConsumed());
194                    assertFinishedOnce(finished, unwrapResult);
195                    return true;
196                }
197
198                case NEED_WRAP: {
199                    // avoid possible overflow
200                    if (output.remaining() != output.capacity()) {
201                        return false;
202                    }
203                    ByteBuffer emptyByteBuffer = ByteBuffer.allocate(0);
204                    int inputPositionBefore = emptyByteBuffer.position();
205                    int outputPositionBefore = output.position();
206                    SSLEngineResult wrapResult = engine.wrap(emptyByteBuffer, output);
207                    assertEquals(SSLEngineResult.Status.OK, wrapResult.getStatus());
208                    assertEquals(0, wrapResult.bytesConsumed());
209                    assertEquals(inputPositionBefore, emptyByteBuffer.position());
210                    assertEquals(output.position() - outputPositionBefore,
211                            wrapResult.bytesProduced());
212                    assertFinishedOnce(finished, wrapResult);
213                    return true;
214                }
215
216                case NOT_HANDSHAKING:
217                    // should have been checked by caller before calling
218                case FINISHED:
219                    // only returned by wrap/unrap status, not getHandshakeStatus
220                    throw new IllegalStateException("Unexpected HandshakeStatus = " + status);
221                default:
222                    throw new IllegalStateException("Unknown HandshakeStatus = " + status);
223            }
224        } finally {
225            // shift consumed input, restore to output mode
226            input.compact();
227        }
228    }
229
230    private static void assertFinishedOnce(boolean[] finishedOut, SSLEngineResult result) {
231        if (result.getHandshakeStatus() == HandshakeStatus.FINISHED) {
232            assertFalse("should only return FINISHED once", finishedOut[0]);
233            finishedOut[0] = true;
234        }
235    }
236}
237