AbstractSessionContext.java revision 8acd6134dc84b387608746fbf2054c6d7dcd4f52
1/*
2 * Copyright (C) 2009 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.apache.harmony.xnet.provider.jsse;
18
19import java.io.ByteArrayInputStream;
20import java.io.ByteArrayOutputStream;
21import java.io.DataInputStream;
22import java.io.DataOutputStream;
23import java.io.IOException;
24import java.security.cert.Certificate;
25import java.security.cert.CertificateEncodingException;
26import java.util.Arrays;
27import java.util.Enumeration;
28import java.util.Iterator;
29import java.util.LinkedHashMap;
30import java.util.Map;
31import java.util.NoSuchElementException;
32import javax.net.ssl.SSLSession;
33import javax.net.ssl.SSLSessionContext;
34import org.apache.harmony.security.provider.cert.X509CertImpl;
35
36/**
37 * Supports SSL session caches.
38 */
39abstract class AbstractSessionContext implements SSLSessionContext {
40
41    volatile int maximumSize;
42    volatile int timeout;
43
44    final long sslCtxNativePointer = NativeCrypto.SSL_CTX_new();
45
46    /** Identifies OpenSSL sessions. */
47    static final int OPEN_SSL = 1;
48
49    private final Map<ByteArray, SSLSession> sessions
50            = new LinkedHashMap<ByteArray, SSLSession>() {
51        @Override
52        protected boolean removeEldestEntry(
53                Map.Entry<ByteArray, SSLSession> eldest) {
54            return maximumSize > 0 && size() > maximumSize;
55        }
56    };
57
58    /**
59     * Constructs a new session context.
60     *
61     * @param maximumSize of cache
62     * @param timeout for cache entries
63     */
64    AbstractSessionContext(int maximumSize, int timeout) {
65        this.maximumSize = maximumSize;
66        this.timeout = timeout;
67    }
68
69    /**
70     * Returns the collection of sessions ordered from oldest to newest
71     */
72    private Iterator<SSLSession> sessionIterator() {
73        synchronized (sessions) {
74            SSLSession[] array = sessions.values().toArray(
75                    new SSLSession[sessions.size()]);
76            return Arrays.asList(array).iterator();
77        }
78    }
79
80    public final Enumeration getIds() {
81        final Iterator<SSLSession> i = sessionIterator();
82        return new Enumeration<byte[]>() {
83            private SSLSession next;
84            public boolean hasMoreElements() {
85                if (next != null) {
86                    return true;
87                }
88                while (i.hasNext()) {
89                    SSLSession session = i.next();
90                    if (session.isValid()) {
91                        next = session;
92                        return true;
93                    }
94                }
95                next = null;
96                return false;
97            }
98            public byte[] nextElement() {
99                if (hasMoreElements()) {
100                    byte[] id = next.getId();
101                    next = null;
102                    return id;
103                }
104                throw new NoSuchElementException();
105            }
106        };
107    }
108
109    public final int getSessionCacheSize() {
110        return maximumSize;
111    }
112
113    public final int getSessionTimeout() {
114        return timeout;
115    }
116
117    /**
118     * Makes sure cache size is < maximumSize.
119     */
120    protected void trimToSize() {
121        synchronized (sessions) {
122            int size = sessions.size();
123            if (size > maximumSize) {
124                int removals = size - maximumSize;
125                Iterator<SSLSession> i = sessions.values().iterator();
126                do {
127                    SSLSession session = i.next();
128                    i.remove();
129                    sessionRemoved(session);
130                } while (--removals > 0);
131            }
132        }
133    }
134
135    public void setSessionTimeout(int seconds)
136            throws IllegalArgumentException {
137        if (seconds < 0) {
138            throw new IllegalArgumentException("seconds < 0");
139        }
140        timeout = seconds;
141
142        synchronized (sessions) {
143            Iterator<SSLSession> i = sessions.values().iterator();
144            while (i.hasNext()) {
145                SSLSession session = i.next();
146                // SSLSession's know their context and consult the
147                // timeout as part of their validity condition.
148                if (!session.isValid()) {
149                    i.remove();
150                    sessionRemoved(session);
151                }
152            }
153        }
154    }
155
156    /**
157     * Called when a session is removed. Used by ClientSessionContext
158     * to update its host-and-port based cache.
159     */
160    protected abstract void sessionRemoved(SSLSession session);
161
162    public final void setSessionCacheSize(int size)
163            throws IllegalArgumentException {
164        if (size < 0) {
165            throw new IllegalArgumentException("size < 0");
166        }
167
168        int oldMaximum = maximumSize;
169        maximumSize = size;
170
171        // Trim cache to size if necessary.
172        if (size < oldMaximum) {
173            trimToSize();
174        }
175    }
176
177    /**
178     * Converts the given session to bytes.
179     *
180     * @return session data as bytes or null if the session can't be converted
181     */
182    byte[] toBytes(SSLSession session) {
183        // TODO: Support SSLSessionImpl, too.
184        if (!(session instanceof OpenSSLSessionImpl)) {
185            return null;
186        }
187
188        OpenSSLSessionImpl sslSession = (OpenSSLSessionImpl) session;
189        try {
190            ByteArrayOutputStream baos = new ByteArrayOutputStream();
191            DataOutputStream daos = new DataOutputStream(baos);
192
193            daos.writeInt(OPEN_SSL); // session type ID
194
195            // Session data.
196            byte[] data = sslSession.getEncoded();
197            daos.writeInt(data.length);
198            daos.write(data);
199
200            // Certificates.
201            Certificate[] certs = session.getPeerCertificates();
202            daos.writeInt(certs.length);
203
204            for (Certificate cert : certs) {
205                data = cert.getEncoded();
206                daos.writeInt(data.length);
207                daos.write(data);
208            }
209            // TODO: local certificates?
210
211            return baos.toByteArray();
212        } catch (IOException e) {
213            log(e);
214            return null;
215        } catch (CertificateEncodingException e) {
216            log(e);
217            return null;
218        }
219    }
220
221    /**
222     * Creates a session from the given bytes.
223     *
224     * @return a session or null if the session can't be converted
225     */
226    SSLSession toSession(byte[] data, String host, int port) {
227        ByteArrayInputStream bais = new ByteArrayInputStream(data);
228        DataInputStream dais = new DataInputStream(bais);
229        try {
230            int type = dais.readInt();
231            if (type != OPEN_SSL) {
232                log(new AssertionError("Unexpected type ID: " + type));
233                return null;
234            }
235
236            int length = dais.readInt();
237            byte[] sessionData = new byte[length];
238            dais.readFully(sessionData);
239
240            int count = dais.readInt();
241            X509CertImpl[] certs = new X509CertImpl[count];
242            for (int i = 0; i < count; i++) {
243                length = dais.readInt();
244                byte[] certData = new byte[length];
245                dais.readFully(certData);
246                certs[i] = new X509CertImpl(certData);
247            }
248
249            return new OpenSSLSessionImpl(sessionData, host, port, certs, this);
250        } catch (IOException e) {
251            log(e);
252            return null;
253        }
254    }
255
256    public SSLSession getSession(byte[] sessionId) {
257        if (sessionId == null) {
258            throw new NullPointerException("sessionId == null");
259        }
260        ByteArray key = new ByteArray(sessionId);
261        SSLSession session;
262        synchronized (sessions) {
263            session = sessions.get(key);
264        }
265        if (session != null && session.isValid()) {
266            return session;
267        }
268        return null;
269    }
270
271    void putSession(SSLSession session) {
272        byte[] id = session.getId();
273        if (id.length == 0) {
274            return;
275        }
276        ByteArray key = new ByteArray(id);
277        synchronized (sessions) {
278            sessions.put(key, session);
279        }
280    }
281
282    static void log(Throwable t) {
283        System.logW("Error converting session.", t);
284    }
285
286    @Override protected void finalize() throws Throwable {
287        try {
288            NativeCrypto.SSL_CTX_free(sslCtxNativePointer);
289        } finally {
290            super.finalize();
291        }
292    }
293}
294