1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/ssl_session_cache_openssl.h"
6
7#include <openssl/ssl.h>
8
9#include "base/lazy_instance.h"
10#include "base/logging.h"
11#include "base/strings/stringprintf.h"
12#include "crypto/openssl_util.h"
13#include "crypto/scoped_openssl_types.h"
14
15#include "testing/gtest/include/gtest/gtest.h"
16
17// This is an internal OpenSSL function that can be used to create a new
18// session for an existing SSL object. This shall force a call to the
19// 'generate_session_id' callback from the SSL's session context.
20// |s| is the target SSL connection handle.
21// |session| is non-0 to ask for the creation of a new session. If 0,
22// this will set an empty session with no ID instead.
23extern "C" int ssl_get_new_session(SSL* s, int session);
24
25// This is an internal OpenSSL function which is used internally to add
26// a new session to the cache. It is normally triggered by a succesful
27// connection. However, this unit test does not use the network at all.
28extern "C" void ssl_update_cache(SSL* s, int mode);
29
30namespace net {
31
32namespace {
33
34typedef crypto::ScopedOpenSSL<SSL, SSL_free>::Type ScopedSSL;
35typedef crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ScopedSSL_CTX;
36
37// Helper class used to associate arbitrary std::string keys with SSL objects.
38class SSLKeyHelper {
39 public:
40  // Return the string associated with a given SSL handle |ssl|, or the
41  // empty string if none exists.
42  static std::string Get(const SSL* ssl) {
43    return GetInstance()->GetValue(ssl);
44  }
45
46  // Associate a string with a given SSL handle |ssl|.
47  static void Set(SSL* ssl, const std::string& value) {
48    GetInstance()->SetValue(ssl, value);
49  }
50
51  static SSLKeyHelper* GetInstance() {
52    static base::LazyInstance<SSLKeyHelper>::Leaky s_instance =
53        LAZY_INSTANCE_INITIALIZER;
54    return s_instance.Pointer();
55  }
56
57  SSLKeyHelper() {
58    ex_index_ = SSL_get_ex_new_index(0, NULL, NULL, KeyDup, KeyFree);
59    CHECK_NE(-1, ex_index_);
60  }
61
62  std::string GetValue(const SSL* ssl) {
63    std::string* value =
64        reinterpret_cast<std::string*>(SSL_get_ex_data(ssl, ex_index_));
65    if (!value)
66      return std::string();
67    return *value;
68  }
69
70  void SetValue(SSL* ssl, const std::string& value) {
71    int ret = SSL_set_ex_data(ssl, ex_index_, new std::string(value));
72    CHECK_EQ(1, ret);
73  }
74
75  // Called when an SSL object is copied through SSL_dup(). This needs to copy
76  // the value as well.
77  static int KeyDup(CRYPTO_EX_DATA* to,
78                    const CRYPTO_EX_DATA* from,
79                    void** from_fd,
80                    int idx,
81                    long argl,
82                    void* argp) {
83    // |from_fd| is really the address of a temporary pointer. On input, it
84    // points to the value from the original SSL object. The function must
85    // update it to the address of a copy.
86    std::string** ptr = reinterpret_cast<std::string**>(from_fd);
87    std::string* old_string = *ptr;
88    std::string* new_string = new std::string(*old_string);
89    *ptr = new_string;
90    return 0;  // Ignored by the implementation.
91  }
92
93  // Called to destroy the value associated with an SSL object.
94  static void KeyFree(void* parent,
95                      void* ptr,
96                      CRYPTO_EX_DATA* ad,
97                      int index,
98                      long argl,
99                      void* argp) {
100    std::string* value = reinterpret_cast<std::string*>(ptr);
101    delete value;
102  }
103
104  int ex_index_;
105};
106
107}  // namespace
108
109class SSLSessionCacheOpenSSLTest : public testing::Test {
110 public:
111  SSLSessionCacheOpenSSLTest() {
112    crypto::EnsureOpenSSLInit();
113    ctx_.reset(SSL_CTX_new(SSLv23_client_method()));
114    cache_.Reset(ctx_.get(), kDefaultConfig);
115  }
116
117  // Reset cache configuration.
118  void ResetConfig(const SSLSessionCacheOpenSSL::Config& config) {
119    cache_.Reset(ctx_.get(), config);
120  }
121
122  // Helper function to create a new SSL connection object associated with
123  // a given unique |cache_key|. This does _not_ add the session to the cache.
124  // Caller must free the object with SSL_free().
125  SSL* NewSSL(const std::string& cache_key) {
126    SSL* ssl = SSL_new(ctx_.get());
127    if (!ssl)
128      return NULL;
129
130    SSLKeyHelper::Set(ssl, cache_key);  // associate cache key.
131    ResetSessionID(ssl);                // create new unique session ID.
132    return ssl;
133  }
134
135  // Reset the session ID of a given SSL object. This creates a new session
136  // with a new unique random ID. Does not add it to the cache.
137  static void ResetSessionID(SSL* ssl) { ssl_get_new_session(ssl, 1); }
138
139  // Add a given SSL object and its session to the cache.
140  void AddToCache(SSL* ssl) {
141    ssl_update_cache(ssl, ctx_.get()->session_cache_mode);
142  }
143
144  static const SSLSessionCacheOpenSSL::Config kDefaultConfig;
145
146 protected:
147  ScopedSSL_CTX ctx_;
148  // |cache_| must be destroyed before |ctx_| and thus appears after it.
149  SSLSessionCacheOpenSSL cache_;
150};
151
152// static
153const SSLSessionCacheOpenSSL::Config
154    SSLSessionCacheOpenSSLTest::kDefaultConfig = {
155        &SSLKeyHelper::Get,  // key_func
156        1024,                // max_entries
157        256,                 // expiration_check_count
158        60 * 60,             // timeout_seconds
159};
160
161TEST_F(SSLSessionCacheOpenSSLTest, EmptyCacheCreation) {
162  EXPECT_EQ(0U, cache_.size());
163}
164
165TEST_F(SSLSessionCacheOpenSSLTest, CacheOneSession) {
166  ScopedSSL ssl(NewSSL("hello"));
167
168  EXPECT_EQ(0U, cache_.size());
169  AddToCache(ssl.get());
170  EXPECT_EQ(1U, cache_.size());
171  ssl.reset(NULL);
172  EXPECT_EQ(1U, cache_.size());
173}
174
175TEST_F(SSLSessionCacheOpenSSLTest, CacheMultipleSessions) {
176  const size_t kNumItems = 100;
177  int local_id = 1;
178
179  // Add kNumItems to the cache.
180  for (size_t n = 0; n < kNumItems; ++n) {
181    std::string local_id_string = base::StringPrintf("%d", local_id++);
182    ScopedSSL ssl(NewSSL(local_id_string));
183    AddToCache(ssl.get());
184    EXPECT_EQ(n + 1, cache_.size());
185  }
186}
187
188TEST_F(SSLSessionCacheOpenSSLTest, Flush) {
189  const size_t kNumItems = 100;
190  int local_id = 1;
191
192  // Add kNumItems to the cache.
193  for (size_t n = 0; n < kNumItems; ++n) {
194    std::string local_id_string = base::StringPrintf("%d", local_id++);
195    ScopedSSL ssl(NewSSL(local_id_string));
196    AddToCache(ssl.get());
197  }
198  EXPECT_EQ(kNumItems, cache_.size());
199
200  cache_.Flush();
201  EXPECT_EQ(0U, cache_.size());
202}
203
204TEST_F(SSLSessionCacheOpenSSLTest, SetSSLSession) {
205  const std::string key("hello");
206  ScopedSSL ssl(NewSSL(key));
207
208  // First call should fail because the session is not in the cache.
209  EXPECT_FALSE(cache_.SetSSLSession(ssl.get()));
210  SSL_SESSION* session = ssl.get()->session;
211  EXPECT_TRUE(session);
212  EXPECT_EQ(1, session->references);
213
214  AddToCache(ssl.get());
215  EXPECT_EQ(2, session->references);
216
217  // Mark the session as good, so that it is re-used for the second connection.
218  cache_.MarkSSLSessionAsGood(ssl.get());
219
220  ssl.reset(NULL);
221  EXPECT_EQ(1, session->references);
222
223  // Second call should find the session ID and associate it with |ssl2|.
224  ScopedSSL ssl2(NewSSL(key));
225  EXPECT_TRUE(cache_.SetSSLSession(ssl2.get()));
226
227  EXPECT_EQ(session, ssl2.get()->session);
228  EXPECT_EQ(2, session->references);
229}
230
231TEST_F(SSLSessionCacheOpenSSLTest, SetSSLSessionWithKey) {
232  const std::string key("hello");
233  ScopedSSL ssl(NewSSL(key));
234  AddToCache(ssl.get());
235  cache_.MarkSSLSessionAsGood(ssl.get());
236  ssl.reset(NULL);
237
238  ScopedSSL ssl2(NewSSL(key));
239  EXPECT_TRUE(cache_.SetSSLSessionWithKey(ssl2.get(), key));
240}
241
242TEST_F(SSLSessionCacheOpenSSLTest, CheckSessionReplacement) {
243  // Check that if two SSL connections have the same key, only one
244  // corresponding session can be stored in the cache.
245  const std::string common_key("common-key");
246  ScopedSSL ssl1(NewSSL(common_key));
247  ScopedSSL ssl2(NewSSL(common_key));
248
249  AddToCache(ssl1.get());
250  EXPECT_EQ(1U, cache_.size());
251  EXPECT_EQ(2, ssl1.get()->session->references);
252
253  // This ends up calling OnSessionAdded which will discover that there is
254  // already one session ID associated with the key, and will replace it.
255  AddToCache(ssl2.get());
256  EXPECT_EQ(1U, cache_.size());
257  EXPECT_EQ(1, ssl1.get()->session->references);
258  EXPECT_EQ(2, ssl2.get()->session->references);
259}
260
261// Check that when two connections have the same key, a new session is created
262// if the existing session has not yet been marked "good". Further, after the
263// first session completes, if the second session has replaced it in the cache,
264// new sessions should continue to fail until the currently cached session
265// succeeds.
266TEST_F(SSLSessionCacheOpenSSLTest, CheckSessionReplacementWhenNotGood) {
267  const std::string key("hello");
268  ScopedSSL ssl(NewSSL(key));
269
270  // First call should fail because the session is not in the cache.
271  EXPECT_FALSE(cache_.SetSSLSession(ssl.get()));
272  SSL_SESSION* session = ssl.get()->session;
273  ASSERT_TRUE(session);
274  EXPECT_EQ(1, session->references);
275
276  AddToCache(ssl.get());
277  EXPECT_EQ(2, session->references);
278
279  // Second call should find the session ID, but because it is not yet good,
280  // fail to associate it with |ssl2|.
281  ScopedSSL ssl2(NewSSL(key));
282  EXPECT_FALSE(cache_.SetSSLSession(ssl2.get()));
283  SSL_SESSION* session2 = ssl2.get()->session;
284  ASSERT_TRUE(session2);
285  EXPECT_EQ(1, session2->references);
286
287  EXPECT_NE(session, session2);
288
289  // Add the second connection to the cache. It should replace the first
290  // session, and the cache should hold on to the second session.
291  AddToCache(ssl2.get());
292  EXPECT_EQ(1, session->references);
293  EXPECT_EQ(2, session2->references);
294
295  // Mark the first session as good, simulating it completing.
296  cache_.MarkSSLSessionAsGood(ssl.get());
297
298  // Third call should find the session ID, but because the second session (the
299  // current cache entry) is not yet good, fail to associate it with |ssl3|.
300  ScopedSSL ssl3(NewSSL(key));
301  EXPECT_FALSE(cache_.SetSSLSession(ssl3.get()));
302  EXPECT_NE(session, ssl3.get()->session);
303  EXPECT_NE(session2, ssl3.get()->session);
304  EXPECT_EQ(1, ssl3.get()->session->references);
305}
306
307TEST_F(SSLSessionCacheOpenSSLTest, CheckEviction) {
308  const size_t kMaxItems = 20;
309  int local_id = 1;
310
311  SSLSessionCacheOpenSSL::Config config = kDefaultConfig;
312  config.max_entries = kMaxItems;
313  ResetConfig(config);
314
315  // Add kMaxItems to the cache.
316  for (size_t n = 0; n < kMaxItems; ++n) {
317    std::string local_id_string = base::StringPrintf("%d", local_id++);
318    ScopedSSL ssl(NewSSL(local_id_string));
319
320    AddToCache(ssl.get());
321    EXPECT_EQ(n + 1, cache_.size());
322  }
323
324  // Continue adding new items to the cache, check that old ones are
325  // evicted.
326  for (size_t n = 0; n < kMaxItems; ++n) {
327    std::string local_id_string = base::StringPrintf("%d", local_id++);
328    ScopedSSL ssl(NewSSL(local_id_string));
329
330    AddToCache(ssl.get());
331    EXPECT_EQ(kMaxItems, cache_.size());
332  }
333}
334
335// Check that session expiration works properly.
336TEST_F(SSLSessionCacheOpenSSLTest, CheckExpiration) {
337  const size_t kMaxCheckCount = 10;
338  const size_t kNumEntries = 20;
339
340  SSLSessionCacheOpenSSL::Config config = kDefaultConfig;
341  config.expiration_check_count = kMaxCheckCount;
342  config.timeout_seconds = 1000;
343  ResetConfig(config);
344
345  // Add |kNumItems - 1| session entries with crafted time values.
346  for (size_t n = 0; n < kNumEntries - 1U; ++n) {
347    std::string key = base::StringPrintf("%d", static_cast<int>(n));
348    ScopedSSL ssl(NewSSL(key));
349    // Cheat a little: Force the session |time| value, this guarantees that they
350    // are expired, given that ::time() will always return a value that is
351    // past the first 100 seconds after the Unix epoch.
352    ssl.get()->session->time = static_cast<long>(n);
353    AddToCache(ssl.get());
354  }
355  EXPECT_EQ(kNumEntries - 1U, cache_.size());
356
357  // Add nother session which will get the current time, and thus not be
358  // expirable until 1000 seconds have passed.
359  ScopedSSL good_ssl(NewSSL("good-key"));
360  AddToCache(good_ssl.get());
361  good_ssl.reset(NULL);
362  EXPECT_EQ(kNumEntries, cache_.size());
363
364  // Call SetSSLSession() |kMaxCheckCount - 1| times, this shall not expire
365  // any session
366  for (size_t n = 0; n < kMaxCheckCount - 1U; ++n) {
367    ScopedSSL ssl(NewSSL("unknown-key"));
368    cache_.SetSSLSession(ssl.get());
369    EXPECT_EQ(kNumEntries, cache_.size());
370  }
371
372  // Call SetSSLSession another time, this shall expire all sessions except
373  // the last one.
374  ScopedSSL bad_ssl(NewSSL("unknown-key"));
375  cache_.SetSSLSession(bad_ssl.get());
376  bad_ssl.reset(NULL);
377  EXPECT_EQ(1U, cache_.size());
378}
379
380}  // namespace net
381