AbstractSessionContext.java revision 9acacc36bafda869c6e9cc63786cdddd995ca96a
1/*
2 * Copyright (C) 2009 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.apache.harmony.xnet.provider.jsse;
18
19import java.io.ByteArrayInputStream;
20import java.io.ByteArrayOutputStream;
21import java.io.DataInputStream;
22import java.io.DataOutputStream;
23import java.io.IOException;
24import java.util.Arrays;
25import java.util.Enumeration;
26import java.util.Iterator;
27import java.util.LinkedHashMap;
28import java.util.Map;
29import java.util.NoSuchElementException;
30import java.util.logging.Level;
31import javax.net.ssl.SSLSession;
32import javax.net.ssl.SSLSessionContext;
33import javax.security.cert.CertificateEncodingException;
34import javax.security.cert.CertificateException;
35import javax.security.cert.X509Certificate;
36
37/**
38 * Supports SSL session caches.
39 */
40abstract class AbstractSessionContext implements SSLSessionContext {
41
42    volatile int maximumSize;
43    volatile int timeout;
44
45    final int sslCtxNativePointer = NativeCrypto.SSL_CTX_new();
46
47    /** Identifies OpenSSL sessions. */
48    static final int OPEN_SSL = 1;
49
50    private final Map<ByteArray, SSLSession> sessions
51            = new LinkedHashMap<ByteArray, SSLSession>() {
52        @Override
53        protected boolean removeEldestEntry(
54                Map.Entry<ByteArray, SSLSession> eldest) {
55            return maximumSize > 0 && size() > maximumSize;
56        }
57    };
58
59    /**
60     * Constructs a new session context.
61     *
62     * @param maximumSize of cache
63     * @param timeout for cache entries
64     */
65    AbstractSessionContext(int maximumSize, int timeout) {
66        this.maximumSize = maximumSize;
67        this.timeout = timeout;
68    }
69
70    /**
71     * Returns the collection of sessions ordered from oldest to newest
72     */
73    private Iterator<SSLSession> sessionIterator() {
74        synchronized (sessions) {
75            SSLSession[] array = sessions.values().toArray(
76                    new SSLSession[sessions.size()]);
77            return Arrays.asList(array).iterator();
78        }
79    }
80
81    public final Enumeration getIds() {
82        final Iterator<SSLSession> i = sessionIterator();
83        return new Enumeration<byte[]>() {
84            private SSLSession next;
85            public boolean hasMoreElements() {
86                if (next != null) {
87                    return true;
88                }
89                while (i.hasNext()) {
90                    SSLSession session = i.next();
91                    if (session.isValid()) {
92                        next = session;
93                        return true;
94                    }
95                }
96                next = null;
97                return false;
98            }
99            public byte[] nextElement() {
100                if (hasMoreElements()) {
101                    byte[] id = next.getId();
102                    next = null;
103                    return id;
104                }
105                throw new NoSuchElementException();
106            }
107        };
108    }
109
110    public final int getSessionCacheSize() {
111        return maximumSize;
112    }
113
114    public final int getSessionTimeout() {
115        return timeout;
116    }
117
118    /**
119     * Makes sure cache size is < maximumSize.
120     */
121    protected void trimToSize() {
122        synchronized (sessions) {
123            int size = sessions.size();
124            if (size > maximumSize) {
125                int removals = size - maximumSize;
126                Iterator<SSLSession> i = sessions.values().iterator();
127                do {
128                    SSLSession session = i.next();
129                    i.remove();
130                    sessionRemoved(session);
131                } while (--removals > 0);
132            }
133        }
134    }
135
136    public void setSessionTimeout(int seconds)
137            throws IllegalArgumentException {
138        if (seconds < 0) {
139            throw new IllegalArgumentException("seconds < 0");
140        }
141        timeout = seconds;
142
143        synchronized (sessions) {
144            Iterator<SSLSession> i = sessions.values().iterator();
145            while (i.hasNext()) {
146                SSLSession session = i.next();
147                // SSLSession's know their context and consult the
148                // timeout as part of their validity condition.
149                if (!session.isValid()) {
150                    i.remove();
151                    sessionRemoved(session);
152                }
153            }
154        }
155    }
156
157    /**
158     * Called when a session is removed. Used by ClientSessionContext
159     * to update its host-and-port based cache.
160     */
161    abstract protected void sessionRemoved(SSLSession session);
162
163    public final void setSessionCacheSize(int size)
164            throws IllegalArgumentException {
165        if (size < 0) {
166            throw new IllegalArgumentException("size < 0");
167        }
168
169        int oldMaximum = maximumSize;
170        maximumSize = size;
171
172        // Trim cache to size if necessary.
173        if (size < oldMaximum) {
174            trimToSize();
175        }
176    }
177
178    /**
179     * Converts the given session to bytes.
180     *
181     * @return session data as bytes or null if the session can't be converted
182     */
183    byte[] toBytes(SSLSession session) {
184        // TODO: Support SSLSessionImpl, too.
185        if (!(session instanceof OpenSSLSessionImpl)) {
186            return null;
187        }
188
189        OpenSSLSessionImpl sslSession = (OpenSSLSessionImpl) session;
190        try {
191            ByteArrayOutputStream baos = new ByteArrayOutputStream();
192            DataOutputStream daos = new DataOutputStream(baos);
193
194            daos.writeInt(OPEN_SSL); // session type ID
195
196            // Session data.
197            byte[] data = sslSession.getEncoded();
198            daos.writeInt(data.length);
199            daos.write(data);
200
201            // Certificates.
202            X509Certificate[] certs = session.getPeerCertificateChain();
203            daos.writeInt(certs.length);
204
205            // TODO: Call nativegetpeercertificates()
206            for (X509Certificate cert : certs) {
207                data = cert.getEncoded();
208                daos.writeInt(data.length);
209                daos.write(data);
210            }
211            // TODO: local certificates?
212
213            return baos.toByteArray();
214        } catch (IOException e) {
215            log(e);
216            return null;
217        } catch (CertificateEncodingException e) {
218            log(e);
219            return null;
220        }
221    }
222
223    /**
224     * Creates a session from the given bytes.
225     *
226     * @return a session or null if the session can't be converted
227     */
228    SSLSession toSession(byte[] data, String host, int port) {
229        ByteArrayInputStream bais = new ByteArrayInputStream(data);
230        DataInputStream dais = new DataInputStream(bais);
231        try {
232            int type = dais.readInt();
233            if (type != OPEN_SSL) {
234                log(new AssertionError("Unexpected type ID: " + type));
235                return null;
236            }
237
238            int length = dais.readInt();
239            byte[] sessionData = new byte[length];
240            dais.readFully(sessionData);
241
242            int count = dais.readInt();
243            X509Certificate[] certs = new X509Certificate[count];
244            for (int i = 0; i < count; i++) {
245                length = dais.readInt();
246                byte[] certData = new byte[length];
247                dais.readFully(certData);
248                certs[i] = X509Certificate.getInstance(certData);
249            }
250
251            return new OpenSSLSessionImpl(sessionData, host, port, certs, this);
252        } catch (IOException e) {
253            log(e);
254            return null;
255        } catch (CertificateException e) {
256            log(e);
257            return null;
258        }
259    }
260
261    public SSLSession getSession(byte[] sessionId) {
262        if (sessionId == null) {
263            throw new NullPointerException("sessionId == null");
264        }
265        ByteArray key = new ByteArray(sessionId);
266        SSLSession session;
267        synchronized (sessions) {
268            session = sessions.get(key);
269        }
270        if (session != null && session.isValid()) {
271            return session;
272        }
273        return null;
274    }
275
276    void putSession(SSLSession session) {
277        byte[] id = session.getId();
278        if (id.length == 0) {
279            return;
280        }
281        ByteArray key = new ByteArray(id);
282        synchronized (sessions) {
283            sessions.put(key, session);
284        }
285    }
286
287    static void log(Throwable t) {
288        java.util.logging.Logger.global.log(Level.WARNING,
289                "Error converting session.", t);
290    }
291
292    protected void finalize() throws IOException {
293        NativeCrypto.SSL_CTX_free(sslCtxNativePointer);
294    }
295
296    /**
297     * Byte array wrapper. Implements equals() and hashCode().
298     */
299    static class ByteArray {
300
301        private final byte[] bytes;
302
303        ByteArray(byte[] bytes) {
304            this.bytes = bytes;
305        }
306
307        @Override
308        public int hashCode() {
309            return Arrays.hashCode(bytes);
310        }
311
312        @Override
313        @SuppressWarnings("EqualsWhichDoesntCheckParameterClass")
314        public boolean equals(Object o) {
315            ByteArray other = (ByteArray) o;
316            return Arrays.equals(bytes, other.bytes);
317        }
318    }
319}
320