1/* 2 * Copyright (C) 2009 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package org.conscrypt; 18 19import java.io.ByteArrayOutputStream; 20import java.io.DataOutputStream; 21import java.io.IOException; 22import java.nio.BufferUnderflowException; 23import java.nio.ByteBuffer; 24import java.security.cert.Certificate; 25import java.security.cert.CertificateEncodingException; 26import java.security.cert.X509Certificate; 27import java.util.Arrays; 28import java.util.Enumeration; 29import java.util.Iterator; 30import java.util.LinkedHashMap; 31import java.util.List; 32import java.util.Map; 33import java.util.NoSuchElementException; 34import javax.net.ssl.SSLSession; 35import javax.net.ssl.SSLSessionContext; 36 37/** 38 * Supports SSL session caches. 39 */ 40abstract class AbstractSessionContext implements SSLSessionContext { 41 42 /** 43 * Maximum lifetime of a session (in seconds) after which it's considered invalid and should not 44 * be used to for new connections. 45 */ 46 private static final int DEFAULT_SESSION_TIMEOUT_SECONDS = 8 * 60 * 60; 47 48 volatile int maximumSize; 49 volatile int timeout = DEFAULT_SESSION_TIMEOUT_SECONDS; 50 51 final long sslCtxNativePointer = NativeCrypto.SSL_CTX_new(); 52 53 /** Identifies OpenSSL sessions. */ 54 static final int OPEN_SSL = 1; 55 56 /** Identifies OpenSSL sessions with OCSP stapled data. */ 57 static final int OPEN_SSL_WITH_OCSP = 2; 58 59 /** Identifies OpenSSL sessions with TLS SCT data. */ 60 static final int OPEN_SSL_WITH_TLS_SCT = 3; 61 62 @SuppressWarnings("serial") 63 private final Map<ByteArray, SSLSession> sessions = new LinkedHashMap<ByteArray, SSLSession>() { 64 @Override 65 protected boolean removeEldestEntry( 66 Map.Entry<ByteArray, SSLSession> eldest) { 67 boolean remove = maximumSize > 0 && size() > maximumSize; 68 if (remove) { 69 remove(eldest.getKey()); 70 sessionRemoved(eldest.getValue()); 71 } 72 return false; 73 } 74 }; 75 76 /** 77 * Constructs a new session context. 78 * 79 * @param maximumSize of cache 80 */ 81 AbstractSessionContext(int maximumSize) { 82 this.maximumSize = maximumSize; 83 } 84 85 /** 86 * Returns the collection of sessions ordered from oldest to newest 87 */ 88 private Iterator<SSLSession> sessionIterator() { 89 synchronized (sessions) { 90 SSLSession[] array = sessions.values().toArray( 91 new SSLSession[sessions.size()]); 92 return Arrays.asList(array).iterator(); 93 } 94 } 95 96 @Override 97 public final Enumeration<byte[]> getIds() { 98 final Iterator<SSLSession> i = sessionIterator(); 99 return new Enumeration<byte[]>() { 100 private SSLSession next; 101 102 @Override 103 public boolean hasMoreElements() { 104 if (next != null) { 105 return true; 106 } 107 while (i.hasNext()) { 108 SSLSession session = i.next(); 109 if (session.isValid()) { 110 next = session; 111 return true; 112 } 113 } 114 next = null; 115 return false; 116 } 117 118 @Override 119 public byte[] nextElement() { 120 if (hasMoreElements()) { 121 byte[] id = next.getId(); 122 next = null; 123 return id; 124 } 125 throw new NoSuchElementException(); 126 } 127 }; 128 } 129 130 @Override 131 public final int getSessionCacheSize() { 132 return maximumSize; 133 } 134 135 @Override 136 public final int getSessionTimeout() { 137 return timeout; 138 } 139 140 /** 141 * Makes sure cache size is < maximumSize. 142 */ 143 protected void trimToSize() { 144 synchronized (sessions) { 145 int size = sessions.size(); 146 if (size > maximumSize) { 147 int removals = size - maximumSize; 148 Iterator<SSLSession> i = sessions.values().iterator(); 149 do { 150 SSLSession session = i.next(); 151 i.remove(); 152 sessionRemoved(session); 153 } while (--removals > 0); 154 } 155 } 156 } 157 158 @Override 159 public void setSessionTimeout(int seconds) 160 throws IllegalArgumentException { 161 if (seconds < 0) { 162 throw new IllegalArgumentException("seconds < 0"); 163 } 164 timeout = seconds; 165 166 synchronized (sessions) { 167 Iterator<SSLSession> i = sessions.values().iterator(); 168 while (i.hasNext()) { 169 SSLSession session = i.next(); 170 // SSLSession's know their context and consult the 171 // timeout as part of their validity condition. 172 if (!session.isValid()) { 173 i.remove(); 174 sessionRemoved(session); 175 } 176 } 177 } 178 } 179 180 /** 181 * Called when a session is removed. Used by ClientSessionContext 182 * to update its host-and-port based cache. 183 */ 184 protected abstract void sessionRemoved(SSLSession session); 185 186 @Override 187 public final void setSessionCacheSize(int size) 188 throws IllegalArgumentException { 189 if (size < 0) { 190 throw new IllegalArgumentException("size < 0"); 191 } 192 193 int oldMaximum = maximumSize; 194 maximumSize = size; 195 196 // Trim cache to size if necessary. 197 if (size < oldMaximum) { 198 trimToSize(); 199 } 200 } 201 202 /** 203 * Converts the given session to bytes. 204 * 205 * @return session data as bytes or null if the session can't be converted 206 */ 207 public byte[] toBytes(SSLSession session) { 208 // TODO: Support SSLSessionImpl, too. 209 if (!(session instanceof OpenSSLSessionImpl)) { 210 return null; 211 } 212 213 OpenSSLSessionImpl sslSession = (OpenSSLSessionImpl) session; 214 try { 215 ByteArrayOutputStream baos = new ByteArrayOutputStream(); 216 DataOutputStream daos = new DataOutputStream(baos); 217 218 daos.writeInt(OPEN_SSL_WITH_TLS_SCT); // session type ID 219 220 // Session data. 221 byte[] data = sslSession.getEncoded(); 222 daos.writeInt(data.length); 223 daos.write(data); 224 225 // Certificates. 226 Certificate[] certs = session.getPeerCertificates(); 227 daos.writeInt(certs.length); 228 229 for (Certificate cert : certs) { 230 data = cert.getEncoded(); 231 daos.writeInt(data.length); 232 daos.write(data); 233 } 234 235 List<byte[]> ocspResponses = sslSession.getStatusResponses(); 236 daos.writeInt(ocspResponses.size()); 237 for (byte[] ocspResponse : ocspResponses) { 238 daos.writeInt(ocspResponse.length); 239 daos.write(ocspResponse); 240 } 241 242 byte[] tlsSctData = sslSession.getTlsSctData(); 243 if (tlsSctData != null) { 244 daos.writeInt(tlsSctData.length); 245 daos.write(tlsSctData); 246 } else { 247 daos.writeInt(0); 248 } 249 250 // TODO: local certificates? 251 252 return baos.toByteArray(); 253 } catch (IOException e) { 254 System.err.println("Failed to convert saved SSL Session: " + e.getMessage()); 255 return null; 256 } catch (CertificateEncodingException e) { 257 log(e); 258 return null; 259 } 260 } 261 262 private static void checkRemaining(ByteBuffer buf, int length) throws IOException { 263 if (length < 0) { 264 throw new IOException("Length is negative: " + length); 265 } 266 if (length > buf.remaining()) { 267 throw new IOException( 268 "Length of blob is longer than available: " + length + " > " + buf.remaining()); 269 } 270 } 271 272 /** 273 * Creates a session from the given bytes. 274 * 275 * @return a session or null if the session can't be converted 276 */ 277 public OpenSSLSessionImpl toSession(byte[] data, String host, int port) { 278 ByteBuffer buf = ByteBuffer.wrap(data); 279 try { 280 int type = buf.getInt(); 281 if (type != OPEN_SSL && type != OPEN_SSL_WITH_OCSP && type != OPEN_SSL_WITH_TLS_SCT) { 282 throw new IOException("Unexpected type ID: " + type); 283 } 284 285 int length = buf.getInt(); 286 checkRemaining(buf, length); 287 288 byte[] sessionData = new byte[length]; 289 buf.get(sessionData); 290 291 int count = buf.getInt(); 292 checkRemaining(buf, count); 293 294 X509Certificate[] certs = new X509Certificate[count]; 295 for (int i = 0; i < count; i++) { 296 length = buf.getInt(); 297 checkRemaining(buf, length); 298 299 byte[] certData = new byte[length]; 300 buf.get(certData); 301 try { 302 certs[i] = OpenSSLX509Certificate.fromX509Der(certData); 303 } catch (Exception e) { 304 throw new IOException("Can not read certificate " + i + "/" + count); 305 } 306 } 307 308 byte[] ocspData = null; 309 if (type >= OPEN_SSL_WITH_OCSP) { 310 // We only support one OCSP response now, but in the future 311 // we may support RFC 6961 which has multiple. 312 int countOcspResponses = buf.getInt(); 313 checkRemaining(buf, countOcspResponses); 314 315 if (countOcspResponses >= 1) { 316 int ocspLength = buf.getInt(); 317 checkRemaining(buf, ocspLength); 318 319 ocspData = new byte[ocspLength]; 320 buf.get(ocspData); 321 322 // Skip the rest of the responses. 323 for (int i = 1; i < countOcspResponses; i++) { 324 ocspLength = buf.getInt(); 325 checkRemaining(buf, ocspLength); 326 buf.position(buf.position() + ocspLength); 327 } 328 } 329 } 330 331 byte[] tlsSctData = null; 332 if (type == OPEN_SSL_WITH_TLS_SCT) { 333 int tlsSctDataLength = buf.getInt(); 334 checkRemaining(buf, tlsSctDataLength); 335 336 if (tlsSctDataLength > 0) { 337 tlsSctData = new byte[tlsSctDataLength]; 338 buf.get(tlsSctData); 339 } 340 } 341 342 if (buf.remaining() != 0) { 343 log(new AssertionError("Read entire session, but data still remains; rejecting")); 344 return null; 345 } 346 347 return new OpenSSLSessionImpl(sessionData, host, port, certs, ocspData, tlsSctData, 348 this); 349 } catch (IOException e) { 350 log(e); 351 return null; 352 } catch (BufferUnderflowException e) { 353 log(e); 354 return null; 355 } 356 } 357 358 protected SSLSession wrapSSLSessionIfNeeded(SSLSession session) { 359 if (session instanceof AbstractOpenSSLSession) { 360 return Platform.wrapSSLSession((AbstractOpenSSLSession) session); 361 } else { 362 return session; 363 } 364 } 365 366 @Override 367 public SSLSession getSession(byte[] sessionId) { 368 if (sessionId == null) { 369 throw new NullPointerException("sessionId == null"); 370 } 371 ByteArray key = new ByteArray(sessionId); 372 SSLSession session; 373 synchronized (sessions) { 374 session = sessions.get(key); 375 } 376 if (session != null && session.isValid()) { 377 return wrapSSLSessionIfNeeded(session); 378 } 379 return null; 380 } 381 382 void putSession(SSLSession session) { 383 byte[] id = session.getId(); 384 if (id.length == 0) { 385 return; 386 } 387 ByteArray key = new ByteArray(id); 388 synchronized (sessions) { 389 sessions.put(key, session); 390 } 391 } 392 393 static void log(Throwable t) { 394 System.out.println("Error inflating SSL session: " 395 + (t.getMessage() != null ? t.getMessage() : t.getClass().getName())); 396 } 397 398 @Override 399 protected void finalize() throws Throwable { 400 try { 401 NativeCrypto.SSL_CTX_free(sslCtxNativePointer); 402 } finally { 403 super.finalize(); 404 } 405 } 406} 407