/*
 * Copyright (C) 2010 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package libcore.javax.net.ssl;

import java.io.IOException;
import java.nio.ByteBuffer;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import junit.framework.Assert;

/**
 * TestSSLEnginePair is a convenience class for other tests that want
 * a pair of connected and handshaked client and server SSLEngines for
 * testing.
 */
public final class TestSSLEnginePair extends Assert implements AutoCloseable {
    public final TestSSLContext c;
    public final SSLEngine server;
    public final SSLEngine client;

    private TestSSLEnginePair(TestSSLContext c,
                              SSLEngine server,
                              SSLEngine client) {
        this.c = c;
        this.server = server;
        this.client = client;
    }

    /**
     * Creates a pair using a freshly created TestSSLContext and no hooks.
     */
    public static TestSSLEnginePair create() throws IOException {
        return create(null);
    }

    /**
     * Creates a pair using a freshly created TestSSLContext.
     *
     * @param hooks callbacks invoked during setup, or null for none
     */
    public static TestSSLEnginePair create(Hooks hooks) throws IOException {
        return create(TestSSLContext.create(), hooks);
    }

    /**
     * Creates a pair within the given context.
     *
     * @param c the context supplying the client and server SSLContexts
     * @param hooks callbacks invoked during setup, or null for none
     */
    public static TestSSLEnginePair create(TestSSLContext c, Hooks hooks) throws IOException {
        return create(c, hooks, null);
    }

    /**
     * Creates a pair within the given context, optionally reporting
     * which sides observed the FINISHED handshake status.
     *
     * @param c the context supplying the client and server SSLContexts
     * @param hooks callbacks invoked during setup, or null for none
     * @param finished see {@link #connect(TestSSLContext, Hooks, boolean[])}
     */
    public static TestSSLEnginePair create(TestSSLContext c, Hooks hooks, boolean[] finished)
            throws IOException {
        SSLEngine[] engines = connect(c, hooks, finished);
        // connect returns { server, client }
        return new TestSSLEnginePair(c, engines[0], engines[1]);
    }

    /**
     * Same as {@link #connect(TestSSLContext, Hooks, boolean[])} with
     * no finished out-parameter.
     */
    public static SSLEngine[] connect(TestSSLContext c, Hooks hooks) throws IOException {
        return connect(c, hooks, null);
    }

    /**
     * Create a new server/client engine pair within an existing
     * TestSSLContext and drive the handshake between them through
     * in-memory buffers.
     *
     * <p>The handshake loop stops when both engines report
     * NOT_HANDSHAKING, or when neither engine can make further
     * progress. In the latter case the engines are returned as they
     * are; no exception is thrown by the loop itself.
     *
     * @param c the context supplying the client and server SSLContexts
     * @param hooks callbacks invoked during setup, or null for none
     * @param finished if non-null, must have length 2; on return
     *     element 0 tells whether the client saw FINISHED and element
     *     1 whether the server did. Note the order is client first,
     *     the reverse of the returned engine array.
     * @return a two element array of { server, client }
     */
    public static SSLEngine[] connect(final TestSSLContext c,
                                      Hooks hooks,
                                      boolean[] finished) throws IOException {
        if (hooks == null) {
            hooks = new Hooks();
        }

        // FINISHED state should be returned only once.
        boolean[] clientFinished = new boolean[1];
        boolean[] serverFinished = new boolean[1];

        // A throwaway engine is used only to learn the buffer sizes.
        SSLSession session = c.clientContext.createSSLEngine().getSession();

        int packetBufferSize = session.getPacketBufferSize();
        ByteBuffer clientToServer = ByteBuffer.allocate(packetBufferSize);
        ByteBuffer serverToClient = ByteBuffer.allocate(packetBufferSize);

        int applicationBufferSize = session.getApplicationBufferSize();
        ByteBuffer scratch = ByteBuffer.allocate(applicationBufferSize);

        SSLEngine client = c.clientContext.createSSLEngine(c.host.getHostName(), c.port);
        SSLEngine server = c.serverContext.createSSLEngine();
        client.setUseClientMode(true);
        server.setUseClientMode(false);
        hooks.beforeBeginHandshake(client, server);
        client.beginHandshake();
        server.beginHandshake();

        while (true) {
            boolean clientDone = client.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING;
            boolean serverDone = server.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING;
            if (clientDone && serverDone) {
                break;
            }

            boolean progress = false;
            if (!clientDone) {
                progress |= handshakeCompleted(client,
                                               clientToServer,
                                               serverToClient,
                                               scratch,
                                               clientFinished);
            }
            if (!serverDone) {
                progress |= handshakeCompleted(server,
                                               serverToClient,
                                               clientToServer,
                                               scratch,
                                               serverFinished);
            }
            if (!progress) {
                // neither side could advance; stop rather than spin forever
                break;
            }
        }

        if (finished != null) {
            assertEquals(2, finished.length);
            finished[0] = clientFinished[0];
            finished[1] = serverFinished[0];
        }
        return new SSLEngine[] { server, client };
    }

    /**
     * Callbacks that let a test adjust the engines during setup. The
     * default implementation does nothing.
     */
    public static class Hooks {
        /**
         * Called after client/server mode has been set on both engines
         * and before beginHandshake is called on either.
         */
        void beforeBeginHandshake(SSLEngine client, SSLEngine server) {}
    }

    /**
     * Closes both engines of this pair.
     *
     * @throws RuntimeException wrapping any failure raised while closing
     */
    @Override
    public void close() throws SSLException {
        close(new SSLEngine[] { client, server });
    }

    /**
     * Closes the inbound and then the outbound side of every non-null
     * engine in the array. Every engine is attempted even if closing
     * an earlier one fails, and closeOutbound is attempted even if
     * closeInbound fails on the same engine.
     *
     * @param engines the engines to close; null elements are skipped
     * @throws RuntimeException wrapping the first failure; failures
     *     from later engines are attached as suppressed exceptions
     */
    public static void close(SSLEngine[] engines) {
        RuntimeException failure = null;
        for (SSLEngine engine : engines) {
            if (engine == null) {
                continue;
            }
            try {
                try {
                    engine.closeInbound();
                } finally {
                    // always release the outbound side as well
                    engine.closeOutbound();
                }
            } catch (Exception e) {
                if (failure == null) {
                    failure = new RuntimeException(e);
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Performs a single handshake step on the given engine according
     * to its current HandshakeStatus.
     *
     * <p>Despite the name, the return value reports whether this step
     * made progress, not whether the handshake is complete.
     *
     * @param engine the engine to advance
     * @param output buffer this engine wraps handshake data into
     * @param input buffer holding the peer's output; flipped for
     *     reading on entry and compacted again on exit
     * @param scratch destination for unwrap; expected to stay empty
     * @param finished single element flag set when a wrap or unwrap
     *     result reports FINISHED
     * @return true if a task was run or data was wrapped or unwrapped
     * @throws IllegalStateException if the engine is not in a state
     *     that needs a handshake step
     */
    private static boolean handshakeCompleted(SSLEngine engine,
                                              ByteBuffer output,
                                              ByteBuffer input,
                                              ByteBuffer scratch,
                                              boolean[] finished) throws IOException {
        try {
            // make the other side's output into our input
            input.flip();

            HandshakeStatus status = engine.getHandshakeStatus();
            switch (status) {

                case NEED_TASK: {
                    boolean progress = false;
                    while (true) {
                        Runnable runnable = engine.getDelegatedTask();
                        if (runnable == null) {
                            return progress;
                        }
                        runnable.run();
                        progress = true;
                    }
                }

                case NEED_UNWRAP: {
                    // avoid underflow
                    if (input.remaining() == 0) {
                        return false;
                    }
                    int inputPositionBefore = input.position();
                    SSLEngineResult unwrapResult = engine.unwrap(input, scratch);
                    assertEquals(SSLEngineResult.Status.OK, unwrapResult.getStatus());
                    assertEquals(0, scratch.position());
                    assertEquals(0, unwrapResult.bytesProduced());
                    assertEquals(input.position() - inputPositionBefore,
                                 unwrapResult.bytesConsumed());
                    assertFinishedOnce(finished, unwrapResult);
                    return true;
                }

                case NEED_WRAP: {
                    // avoid possible overflow: only wrap into an empty buffer
                    if (output.remaining() != output.capacity()) {
                        return false;
                    }
                    ByteBuffer emptyByteBuffer = ByteBuffer.allocate(0);
                    int inputPositionBefore = emptyByteBuffer.position();
                    int outputPositionBefore = output.position();
                    SSLEngineResult wrapResult = engine.wrap(emptyByteBuffer, output);
                    assertEquals(SSLEngineResult.Status.OK, wrapResult.getStatus());
                    assertEquals(0, wrapResult.bytesConsumed());
                    assertEquals(inputPositionBefore, emptyByteBuffer.position());
                    assertEquals(output.position() - outputPositionBefore,
                                 wrapResult.bytesProduced());
                    assertFinishedOnce(finished, wrapResult);
                    return true;
                }

                case NOT_HANDSHAKING:
                    // should have been checked by caller before calling
                case FINISHED:
                    // only returned by wrap/unwrap status, not getHandshakeStatus
                    throw new IllegalStateException("Unexpected HandshakeStatus = " + status);
                default:
                    throw new IllegalStateException("Unknown HandshakeStatus = " + status);
            }
        } finally {
            // shift consumed input, restore to output mode
            input.compact();
        }
    }

    /**
     * Records that the given result reported FINISHED, asserting that
     * this has not already been recorded in finishedOut.
     */
    private static void assertFinishedOnce(boolean[] finishedOut, SSLEngineResult result) {
        if (result.getHandshakeStatus() == HandshakeStatus.FINISHED) {
            assertFalse("should only return FINISHED once", finishedOut[0]);
            finishedOut[0] = true;
        }
    }
}