/* * Copyright (C) 2009 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.conscrypt; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.security.cert.Certificate; import java.security.cert.CertificateEncodingException; import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.Enumeration; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map; import java.util.NoSuchElementException; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSessionContext; /** * Supports SSL session caches. */ abstract class AbstractSessionContext implements SSLSessionContext { /** * Maximum lifetime of a session (in seconds) after which it's considered invalid and should not * be used to for new connections. */ private static final int DEFAULT_SESSION_TIMEOUT_SECONDS = 8 * 60 * 60; volatile int maximumSize; volatile int timeout = DEFAULT_SESSION_TIMEOUT_SECONDS; final long sslCtxNativePointer = NativeCrypto.SSL_CTX_new(); /** Identifies OpenSSL sessions. */ static final int OPEN_SSL = 1; private final Map sessions = new LinkedHashMap() { @Override protected boolean removeEldestEntry( Map.Entry eldest) { boolean remove = maximumSize > 0 && size() > maximumSize; if (remove) { remove(eldest.getKey()); sessionRemoved(eldest.getValue()); } return false; } }; /** * Constructs a new session context. * * @param maximumSize of cache */ AbstractSessionContext(int maximumSize) { this.maximumSize = maximumSize; } /** * Returns the collection of sessions ordered from oldest to newest */ private Iterator sessionIterator() { synchronized (sessions) { SSLSession[] array = sessions.values().toArray( new SSLSession[sessions.size()]); return Arrays.asList(array).iterator(); } } @Override public final Enumeration getIds() { final Iterator i = sessionIterator(); return new Enumeration() { private SSLSession next; @Override public boolean hasMoreElements() { if (next != null) { return true; } while (i.hasNext()) { SSLSession session = i.next(); if (session.isValid()) { next = session; return true; } } next = null; return false; } @Override public byte[] nextElement() { if (hasMoreElements()) { byte[] id = next.getId(); next = null; return id; } throw new NoSuchElementException(); } }; } @Override public final int getSessionCacheSize() { return maximumSize; } @Override public final int getSessionTimeout() { return timeout; } /** * Makes sure cache size is < maximumSize. */ protected void trimToSize() { synchronized (sessions) { int size = sessions.size(); if (size > maximumSize) { int removals = size - maximumSize; Iterator i = sessions.values().iterator(); do { SSLSession session = i.next(); i.remove(); sessionRemoved(session); } while (--removals > 0); } } } @Override public void setSessionTimeout(int seconds) throws IllegalArgumentException { if (seconds < 0) { throw new IllegalArgumentException("seconds < 0"); } timeout = seconds; synchronized (sessions) { Iterator i = sessions.values().iterator(); while (i.hasNext()) { SSLSession session = i.next(); // SSLSession's know their context and consult the // timeout as part of their validity condition. if (!session.isValid()) { i.remove(); sessionRemoved(session); } } } } /** * Called when a session is removed. Used by ClientSessionContext * to update its host-and-port based cache. */ protected abstract void sessionRemoved(SSLSession session); @Override public final void setSessionCacheSize(int size) throws IllegalArgumentException { if (size < 0) { throw new IllegalArgumentException("size < 0"); } int oldMaximum = maximumSize; maximumSize = size; // Trim cache to size if necessary. if (size < oldMaximum) { trimToSize(); } } /** * Converts the given session to bytes. * * @return session data as bytes or null if the session can't be converted */ byte[] toBytes(SSLSession session) { // TODO: Support SSLSessionImpl, too. if (!(session instanceof OpenSSLSessionImpl)) { return null; } OpenSSLSessionImpl sslSession = (OpenSSLSessionImpl) session; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); DataOutputStream daos = new DataOutputStream(baos); daos.writeInt(OPEN_SSL); // session type ID // Session data. byte[] data = sslSession.getEncoded(); daos.writeInt(data.length); daos.write(data); // Certificates. Certificate[] certs = session.getPeerCertificates(); daos.writeInt(certs.length); for (Certificate cert : certs) { data = cert.getEncoded(); daos.writeInt(data.length); daos.write(data); } // TODO: local certificates? return baos.toByteArray(); } catch (IOException e) { log(e); return null; } catch (CertificateEncodingException e) { log(e); return null; } } /** * Creates a session from the given bytes. * * @return a session or null if the session can't be converted */ SSLSession toSession(byte[] data, String host, int port) { ByteArrayInputStream bais = new ByteArrayInputStream(data); DataInputStream dais = new DataInputStream(bais); try { int type = dais.readInt(); if (type != OPEN_SSL) { log(new AssertionError("Unexpected type ID: " + type)); return null; } int length = dais.readInt(); byte[] sessionData = new byte[length]; dais.readFully(sessionData); int count = dais.readInt(); X509Certificate[] certs = new X509Certificate[count]; for (int i = 0; i < count; i++) { length = dais.readInt(); byte[] certData = new byte[length]; dais.readFully(certData); certs[i] = OpenSSLX509Certificate.fromX509Der(certData); } return new OpenSSLSessionImpl(sessionData, host, port, certs, this); } catch (IOException e) { log(e); return null; } } @Override public SSLSession getSession(byte[] sessionId) { if (sessionId == null) { throw new NullPointerException("sessionId == null"); } ByteArray key = new ByteArray(sessionId); SSLSession session; synchronized (sessions) { session = sessions.get(key); } if (session != null && session.isValid()) { return session; } return null; } void putSession(SSLSession session) { byte[] id = session.getId(); if (id.length == 0) { return; } ByteArray key = new ByteArray(id); synchronized (sessions) { sessions.put(key, session); } } static void log(Throwable t) { new Exception("Error converting session", t).printStackTrace(); } @Override protected void finalize() throws Throwable { try { NativeCrypto.SSL_CTX_free(sslCtxNativePointer); } finally { super.finalize(); } } }