1/*
2 * Copyright (C) 2009 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.io.ByteArrayInputStream;
20import java.io.ByteArrayOutputStream;
21import java.io.DataInputStream;
22import java.io.DataOutputStream;
23import java.io.IOException;
24import java.security.cert.Certificate;
25import java.security.cert.CertificateEncodingException;
26import java.security.cert.X509Certificate;
27import java.util.Arrays;
28import java.util.Enumeration;
29import java.util.Iterator;
30import java.util.LinkedHashMap;
31import java.util.Map;
32import java.util.NoSuchElementException;
33import javax.net.ssl.SSLSession;
34import javax.net.ssl.SSLSessionContext;
35
36/**
37 * Supports SSL session caches.
38 */
39abstract class AbstractSessionContext implements SSLSessionContext {
40
41    /**
42     * Maximum lifetime of a session (in seconds) after which it's considered invalid and should not
43     * be used to for new connections.
44     */
45    private static final int DEFAULT_SESSION_TIMEOUT_SECONDS = 8 * 60 * 60;
46
47    volatile int maximumSize;
48    volatile int timeout = DEFAULT_SESSION_TIMEOUT_SECONDS;
49
50    final long sslCtxNativePointer = NativeCrypto.SSL_CTX_new();
51
52    /** Identifies OpenSSL sessions. */
53    static final int OPEN_SSL = 1;
54
55    private final Map<ByteArray, SSLSession> sessions
56            = new LinkedHashMap<ByteArray, SSLSession>() {
57        @Override
58        protected boolean removeEldestEntry(
59                Map.Entry<ByteArray, SSLSession> eldest) {
60            boolean remove = maximumSize > 0 && size() > maximumSize;
61            if (remove) {
62                remove(eldest.getKey());
63                sessionRemoved(eldest.getValue());
64            }
65            return false;
66        }
67    };
68
69    /**
70     * Constructs a new session context.
71     *
72     * @param maximumSize of cache
73     */
74    AbstractSessionContext(int maximumSize) {
75        this.maximumSize = maximumSize;
76    }
77
78    /**
79     * Returns the collection of sessions ordered from oldest to newest
80     */
81    private Iterator<SSLSession> sessionIterator() {
82        synchronized (sessions) {
83            SSLSession[] array = sessions.values().toArray(
84                    new SSLSession[sessions.size()]);
85            return Arrays.asList(array).iterator();
86        }
87    }
88
89    @Override
90    public final Enumeration<byte[]> getIds() {
91        final Iterator<SSLSession> i = sessionIterator();
92        return new Enumeration<byte[]>() {
93            private SSLSession next;
94
95            @Override
96            public boolean hasMoreElements() {
97                if (next != null) {
98                    return true;
99                }
100                while (i.hasNext()) {
101                    SSLSession session = i.next();
102                    if (session.isValid()) {
103                        next = session;
104                        return true;
105                    }
106                }
107                next = null;
108                return false;
109            }
110
111            @Override
112            public byte[] nextElement() {
113                if (hasMoreElements()) {
114                    byte[] id = next.getId();
115                    next = null;
116                    return id;
117                }
118                throw new NoSuchElementException();
119            }
120        };
121    }
122
123    @Override
124    public final int getSessionCacheSize() {
125        return maximumSize;
126    }
127
128    @Override
129    public final int getSessionTimeout() {
130        return timeout;
131    }
132
133    /**
134     * Makes sure cache size is < maximumSize.
135     */
136    protected void trimToSize() {
137        synchronized (sessions) {
138            int size = sessions.size();
139            if (size > maximumSize) {
140                int removals = size - maximumSize;
141                Iterator<SSLSession> i = sessions.values().iterator();
142                do {
143                    SSLSession session = i.next();
144                    i.remove();
145                    sessionRemoved(session);
146                } while (--removals > 0);
147            }
148        }
149    }
150
151    @Override
152    public void setSessionTimeout(int seconds)
153            throws IllegalArgumentException {
154        if (seconds < 0) {
155            throw new IllegalArgumentException("seconds < 0");
156        }
157        timeout = seconds;
158
159        synchronized (sessions) {
160            Iterator<SSLSession> i = sessions.values().iterator();
161            while (i.hasNext()) {
162                SSLSession session = i.next();
163                // SSLSession's know their context and consult the
164                // timeout as part of their validity condition.
165                if (!session.isValid()) {
166                    i.remove();
167                    sessionRemoved(session);
168                }
169            }
170        }
171    }
172
173    /**
174     * Called when a session is removed. Used by ClientSessionContext
175     * to update its host-and-port based cache.
176     */
177    protected abstract void sessionRemoved(SSLSession session);
178
179    @Override
180    public final void setSessionCacheSize(int size)
181            throws IllegalArgumentException {
182        if (size < 0) {
183            throw new IllegalArgumentException("size < 0");
184        }
185
186        int oldMaximum = maximumSize;
187        maximumSize = size;
188
189        // Trim cache to size if necessary.
190        if (size < oldMaximum) {
191            trimToSize();
192        }
193    }
194
195    /**
196     * Converts the given session to bytes.
197     *
198     * @return session data as bytes or null if the session can't be converted
199     */
200    byte[] toBytes(SSLSession session) {
201        // TODO: Support SSLSessionImpl, too.
202        if (!(session instanceof OpenSSLSessionImpl)) {
203            return null;
204        }
205
206        OpenSSLSessionImpl sslSession = (OpenSSLSessionImpl) session;
207        try {
208            ByteArrayOutputStream baos = new ByteArrayOutputStream();
209            DataOutputStream daos = new DataOutputStream(baos);
210
211            daos.writeInt(OPEN_SSL); // session type ID
212
213            // Session data.
214            byte[] data = sslSession.getEncoded();
215            daos.writeInt(data.length);
216            daos.write(data);
217
218            // Certificates.
219            Certificate[] certs = session.getPeerCertificates();
220            daos.writeInt(certs.length);
221
222            for (Certificate cert : certs) {
223                data = cert.getEncoded();
224                daos.writeInt(data.length);
225                daos.write(data);
226            }
227            // TODO: local certificates?
228
229            return baos.toByteArray();
230        } catch (IOException e) {
231            log(e);
232            return null;
233        } catch (CertificateEncodingException e) {
234            log(e);
235            return null;
236        }
237    }
238
239    /**
240     * Creates a session from the given bytes.
241     *
242     * @return a session or null if the session can't be converted
243     */
244    SSLSession toSession(byte[] data, String host, int port) {
245        ByteArrayInputStream bais = new ByteArrayInputStream(data);
246        DataInputStream dais = new DataInputStream(bais);
247        try {
248            int type = dais.readInt();
249            if (type != OPEN_SSL) {
250                log(new AssertionError("Unexpected type ID: " + type));
251                return null;
252            }
253
254            int length = dais.readInt();
255            byte[] sessionData = new byte[length];
256            dais.readFully(sessionData);
257
258            int count = dais.readInt();
259            X509Certificate[] certs = new X509Certificate[count];
260            for (int i = 0; i < count; i++) {
261                length = dais.readInt();
262                byte[] certData = new byte[length];
263                dais.readFully(certData);
264                certs[i] = OpenSSLX509Certificate.fromX509Der(certData);
265            }
266
267            return new OpenSSLSessionImpl(sessionData, host, port, certs, this);
268        } catch (IOException e) {
269            log(e);
270            return null;
271        }
272    }
273
274    @Override
275    public SSLSession getSession(byte[] sessionId) {
276        if (sessionId == null) {
277            throw new NullPointerException("sessionId == null");
278        }
279        ByteArray key = new ByteArray(sessionId);
280        SSLSession session;
281        synchronized (sessions) {
282            session = sessions.get(key);
283        }
284        if (session != null && session.isValid()) {
285            return session;
286        }
287        return null;
288    }
289
290    void putSession(SSLSession session) {
291        byte[] id = session.getId();
292        if (id.length == 0) {
293            return;
294        }
295        ByteArray key = new ByteArray(id);
296        synchronized (sessions) {
297            sessions.put(key, session);
298        }
299    }
300
301    static void log(Throwable t) {
302        new Exception("Error converting session", t).printStackTrace();
303    }
304
305    @Override
306    protected void finalize() throws Throwable {
307        try {
308            NativeCrypto.SSL_CTX_free(sslCtxNativePointer);
309        } finally {
310            super.finalize();
311        }
312    }
313}
314