1/*
2 * Copyright (C) 2009 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.util.HashMap;
20import java.util.Map;
21import javax.net.ssl.SSLSession;
22
23/**
24 * Caches client sessions. Indexes by host and port. Users are typically
25 * looking to reuse any session for a given host and port.
26 */
27public class ClientSessionContext extends AbstractSessionContext {
28
29    /**
30     * Sessions indexed by host and port. Protect from concurrent
31     * access by holding a lock on sessionsByHostAndPort.
32     */
33    private final HashMap<HostAndPort, SSLSession> sessionsByHostAndPort = new HashMap<>();
34
35    private SSLClientSessionCache persistentCache;
36
37    public ClientSessionContext() {
38        super(10);
39    }
40
41    public int size() {
42        return sessionsByHostAndPort.size();
43    }
44
45    public void setPersistentCache(SSLClientSessionCache persistentCache) {
46        this.persistentCache = persistentCache;
47    }
48
49    @Override
50    protected void sessionRemoved(SSLSession session) {
51        String host = session.getPeerHost();
52        int port = session.getPeerPort();
53        if (host == null) {
54            return;
55        }
56        HostAndPort hostAndPortKey = new HostAndPort(host, port);
57        synchronized (sessionsByHostAndPort) {
58            sessionsByHostAndPort.remove(hostAndPortKey);
59        }
60    }
61
62    /**
63     * Finds a cached session for the given host name and port.
64     *
65     * @param host of server
66     * @param port of server
67     * @return cached session or null if none found
68     */
69    public SSLSession getSession(String host, int port) {
70        if (host == null) {
71            return null;
72        }
73        SSLSession session;
74        HostAndPort hostAndPortKey = new HostAndPort(host, port);
75        synchronized (sessionsByHostAndPort) {
76            session = sessionsByHostAndPort.get(hostAndPortKey);
77        }
78        if (session != null && session.isValid()) {
79            return wrapSSLSessionIfNeeded(session);
80        }
81
82        // Look in persistent cache.
83        if (persistentCache != null) {
84            byte[] data = persistentCache.getSessionData(host, port);
85            if (data != null) {
86                session = toSession(data, host, port);
87                if (session != null && session.isValid()) {
88                    super.putSession(session);
89                    synchronized (sessionsByHostAndPort) {
90                        sessionsByHostAndPort.put(hostAndPortKey, session);
91                    }
92                    return wrapSSLSessionIfNeeded(session);
93                }
94            }
95        }
96
97        return null;
98    }
99
100    @Override
101    public void putSession(SSLSession session) {
102        super.putSession(session);
103
104        String host = session.getPeerHost();
105        int port = session.getPeerPort();
106        if (host == null) {
107            return;
108        }
109
110        HostAndPort hostAndPortKey = new HostAndPort(host, port);
111        synchronized (sessionsByHostAndPort) {
112            sessionsByHostAndPort.put(hostAndPortKey, session);
113        }
114
115        // TODO: This in a background thread.
116        if (persistentCache != null) {
117            byte[] data = toBytes(session);
118            if (data != null) {
119                persistentCache.putSessionData(session, data);
120            }
121        }
122    }
123
124    static class HostAndPort {
125        final String host;
126        final int port;
127
128        HostAndPort(String host, int port) {
129            this.host = host;
130            this.port = port;
131        }
132
133        @Override
134        public int hashCode() {
135            return host.hashCode() * 31 + port;
136        }
137
138        @Override
139        public boolean equals(Object o) {
140            if (!(o instanceof HostAndPort)) {
141                return false;
142            }
143            HostAndPort lhs = (HostAndPort) o;
144            return host.equals(lhs.host) && port == lhs.port;
145        }
146    }
147}
148