1/*
2 * Copyright (C) 2009 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.io.ByteArrayOutputStream;
20import java.io.DataOutputStream;
21import java.io.IOException;
22import java.nio.BufferUnderflowException;
23import java.nio.ByteBuffer;
24import java.security.cert.Certificate;
25import java.security.cert.CertificateEncodingException;
26import java.security.cert.X509Certificate;
27import java.util.Arrays;
28import java.util.Enumeration;
29import java.util.Iterator;
30import java.util.LinkedHashMap;
31import java.util.List;
32import java.util.Map;
33import java.util.NoSuchElementException;
34import javax.net.ssl.SSLSession;
35import javax.net.ssl.SSLSessionContext;
36
37/**
38 * Supports SSL session caches.
39 */
40abstract class AbstractSessionContext implements SSLSessionContext {
41
42    /**
43     * Maximum lifetime of a session (in seconds) after which it's considered invalid and should not
44     * be used to for new connections.
45     */
46    private static final int DEFAULT_SESSION_TIMEOUT_SECONDS = 8 * 60 * 60;
47
48    volatile int maximumSize;
49    volatile int timeout = DEFAULT_SESSION_TIMEOUT_SECONDS;
50
51    final long sslCtxNativePointer = NativeCrypto.SSL_CTX_new();
52
53    /** Identifies OpenSSL sessions. */
54    static final int OPEN_SSL = 1;
55
56    /** Identifies OpenSSL sessions with OCSP stapled data. */
57    static final int OPEN_SSL_WITH_OCSP = 2;
58
59    /** Identifies OpenSSL sessions with TLS SCT data. */
60    static final int OPEN_SSL_WITH_TLS_SCT = 3;
61
62    @SuppressWarnings("serial")
63    private final Map<ByteArray, SSLSession> sessions = new LinkedHashMap<ByteArray, SSLSession>() {
64        @Override
65        protected boolean removeEldestEntry(
66                Map.Entry<ByteArray, SSLSession> eldest) {
67            boolean remove = maximumSize > 0 && size() > maximumSize;
68            if (remove) {
69                remove(eldest.getKey());
70                sessionRemoved(eldest.getValue());
71            }
72            return false;
73        }
74    };
75
76    /**
77     * Constructs a new session context.
78     *
79     * @param maximumSize of cache
80     */
81    AbstractSessionContext(int maximumSize) {
82        this.maximumSize = maximumSize;
83    }
84
85    /**
86     * Returns the collection of sessions ordered from oldest to newest
87     */
88    private Iterator<SSLSession> sessionIterator() {
89        synchronized (sessions) {
90            SSLSession[] array = sessions.values().toArray(
91                    new SSLSession[sessions.size()]);
92            return Arrays.asList(array).iterator();
93        }
94    }
95
96    @Override
97    public final Enumeration<byte[]> getIds() {
98        final Iterator<SSLSession> i = sessionIterator();
99        return new Enumeration<byte[]>() {
100            private SSLSession next;
101
102            @Override
103            public boolean hasMoreElements() {
104                if (next != null) {
105                    return true;
106                }
107                while (i.hasNext()) {
108                    SSLSession session = i.next();
109                    if (session.isValid()) {
110                        next = session;
111                        return true;
112                    }
113                }
114                next = null;
115                return false;
116            }
117
118            @Override
119            public byte[] nextElement() {
120                if (hasMoreElements()) {
121                    byte[] id = next.getId();
122                    next = null;
123                    return id;
124                }
125                throw new NoSuchElementException();
126            }
127        };
128    }
129
130    @Override
131    public final int getSessionCacheSize() {
132        return maximumSize;
133    }
134
135    @Override
136    public final int getSessionTimeout() {
137        return timeout;
138    }
139
140    /**
141     * Makes sure cache size is < maximumSize.
142     */
143    protected void trimToSize() {
144        synchronized (sessions) {
145            int size = sessions.size();
146            if (size > maximumSize) {
147                int removals = size - maximumSize;
148                Iterator<SSLSession> i = sessions.values().iterator();
149                do {
150                    SSLSession session = i.next();
151                    i.remove();
152                    sessionRemoved(session);
153                } while (--removals > 0);
154            }
155        }
156    }
157
158    @Override
159    public void setSessionTimeout(int seconds)
160            throws IllegalArgumentException {
161        if (seconds < 0) {
162            throw new IllegalArgumentException("seconds < 0");
163        }
164        timeout = seconds;
165
166        synchronized (sessions) {
167            Iterator<SSLSession> i = sessions.values().iterator();
168            while (i.hasNext()) {
169                SSLSession session = i.next();
170                // SSLSession's know their context and consult the
171                // timeout as part of their validity condition.
172                if (!session.isValid()) {
173                    i.remove();
174                    sessionRemoved(session);
175                }
176            }
177        }
178    }
179
180    /**
181     * Called when a session is removed. Used by ClientSessionContext
182     * to update its host-and-port based cache.
183     */
184    protected abstract void sessionRemoved(SSLSession session);
185
186    @Override
187    public final void setSessionCacheSize(int size)
188            throws IllegalArgumentException {
189        if (size < 0) {
190            throw new IllegalArgumentException("size < 0");
191        }
192
193        int oldMaximum = maximumSize;
194        maximumSize = size;
195
196        // Trim cache to size if necessary.
197        if (size < oldMaximum) {
198            trimToSize();
199        }
200    }
201
202    /**
203     * Converts the given session to bytes.
204     *
205     * @return session data as bytes or null if the session can't be converted
206     */
207    public byte[] toBytes(SSLSession session) {
208        // TODO: Support SSLSessionImpl, too.
209        if (!(session instanceof OpenSSLSessionImpl)) {
210            return null;
211        }
212
213        OpenSSLSessionImpl sslSession = (OpenSSLSessionImpl) session;
214        try {
215            ByteArrayOutputStream baos = new ByteArrayOutputStream();
216            DataOutputStream daos = new DataOutputStream(baos);
217
218            daos.writeInt(OPEN_SSL_WITH_TLS_SCT); // session type ID
219
220            // Session data.
221            byte[] data = sslSession.getEncoded();
222            daos.writeInt(data.length);
223            daos.write(data);
224
225            // Certificates.
226            Certificate[] certs = session.getPeerCertificates();
227            daos.writeInt(certs.length);
228
229            for (Certificate cert : certs) {
230                data = cert.getEncoded();
231                daos.writeInt(data.length);
232                daos.write(data);
233            }
234
235            List<byte[]> ocspResponses = sslSession.getStatusResponses();
236            daos.writeInt(ocspResponses.size());
237            for (byte[] ocspResponse : ocspResponses) {
238                daos.writeInt(ocspResponse.length);
239                daos.write(ocspResponse);
240            }
241
242            byte[] tlsSctData = sslSession.getTlsSctData();
243            if (tlsSctData != null) {
244                daos.writeInt(tlsSctData.length);
245                daos.write(tlsSctData);
246            } else {
247                daos.writeInt(0);
248            }
249
250            // TODO: local certificates?
251
252            return baos.toByteArray();
253        } catch (IOException e) {
254            System.err.println("Failed to convert saved SSL Session: " + e.getMessage());
255            return null;
256        } catch (CertificateEncodingException e) {
257            log(e);
258            return null;
259        }
260    }
261
262    private static void checkRemaining(ByteBuffer buf, int length) throws IOException {
263        if (length < 0) {
264            throw new IOException("Length is negative: " + length);
265        }
266        if (length > buf.remaining()) {
267            throw new IOException(
268                    "Length of blob is longer than available: " + length + " > " + buf.remaining());
269        }
270    }
271
272    /**
273     * Creates a session from the given bytes.
274     *
275     * @return a session or null if the session can't be converted
276     */
277    public OpenSSLSessionImpl toSession(byte[] data, String host, int port) {
278        ByteBuffer buf = ByteBuffer.wrap(data);
279        try {
280            int type = buf.getInt();
281            if (type != OPEN_SSL && type != OPEN_SSL_WITH_OCSP && type != OPEN_SSL_WITH_TLS_SCT) {
282                throw new IOException("Unexpected type ID: " + type);
283            }
284
285            int length = buf.getInt();
286            checkRemaining(buf, length);
287
288            byte[] sessionData = new byte[length];
289            buf.get(sessionData);
290
291            int count = buf.getInt();
292            checkRemaining(buf, count);
293
294            X509Certificate[] certs = new X509Certificate[count];
295            for (int i = 0; i < count; i++) {
296                length = buf.getInt();
297                checkRemaining(buf, length);
298
299                byte[] certData = new byte[length];
300                buf.get(certData);
301                try {
302                    certs[i] = OpenSSLX509Certificate.fromX509Der(certData);
303                } catch (Exception e) {
304                    throw new IOException("Can not read certificate " + i + "/" + count);
305                }
306            }
307
308            byte[] ocspData = null;
309            if (type >= OPEN_SSL_WITH_OCSP) {
310                // We only support one OCSP response now, but in the future
311                // we may support RFC 6961 which has multiple.
312                int countOcspResponses = buf.getInt();
313                checkRemaining(buf, countOcspResponses);
314
315                if (countOcspResponses >= 1) {
316                    int ocspLength = buf.getInt();
317                    checkRemaining(buf, ocspLength);
318
319                    ocspData = new byte[ocspLength];
320                    buf.get(ocspData);
321
322                    // Skip the rest of the responses.
323                    for (int i = 1; i < countOcspResponses; i++) {
324                        ocspLength = buf.getInt();
325                        checkRemaining(buf, ocspLength);
326                        buf.position(buf.position() + ocspLength);
327                    }
328                }
329            }
330
331            byte[] tlsSctData = null;
332            if (type == OPEN_SSL_WITH_TLS_SCT) {
333                int tlsSctDataLength = buf.getInt();
334                checkRemaining(buf, tlsSctDataLength);
335
336                if (tlsSctDataLength > 0) {
337                    tlsSctData = new byte[tlsSctDataLength];
338                    buf.get(tlsSctData);
339                }
340            }
341
342            if (buf.remaining() != 0) {
343                log(new AssertionError("Read entire session, but data still remains; rejecting"));
344                return null;
345            }
346
347            return new OpenSSLSessionImpl(sessionData, host, port, certs, ocspData, tlsSctData,
348                    this);
349        } catch (IOException e) {
350            log(e);
351            return null;
352        } catch (BufferUnderflowException e) {
353            log(e);
354            return null;
355        }
356    }
357
358    protected SSLSession wrapSSLSessionIfNeeded(SSLSession session) {
359        if (session instanceof AbstractOpenSSLSession) {
360            return Platform.wrapSSLSession((AbstractOpenSSLSession) session);
361        } else {
362            return session;
363        }
364    }
365
366    @Override
367    public SSLSession getSession(byte[] sessionId) {
368        if (sessionId == null) {
369            throw new NullPointerException("sessionId == null");
370        }
371        ByteArray key = new ByteArray(sessionId);
372        SSLSession session;
373        synchronized (sessions) {
374            session = sessions.get(key);
375        }
376        if (session != null && session.isValid()) {
377            return wrapSSLSessionIfNeeded(session);
378        }
379        return null;
380    }
381
382    void putSession(SSLSession session) {
383        byte[] id = session.getId();
384        if (id.length == 0) {
385            return;
386        }
387        ByteArray key = new ByteArray(id);
388        synchronized (sessions) {
389            sessions.put(key, session);
390        }
391    }
392
393    static void log(Throwable t) {
394        System.out.println("Error inflating SSL session: "
395                + (t.getMessage() != null ? t.getMessage() : t.getClass().getName()));
396    }
397
398    @Override
399    protected void finalize() throws Throwable {
400        try {
401            NativeCrypto.SSL_CTX_free(sslCtxNativePointer);
402        } finally {
403            super.finalize();
404        }
405    }
406}
407