1# Copyright 2014 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Test routines to generate dummy certificates."""
16
import BaseHTTPServer
import os
import shutil
import signal
import socket
import tempfile
import threading
import time
import unittest

import certutils
import sslproxy
28
29
class Client(object):
  """Minimal SSL client used to exercise the test server.

  Connects to (host, port), sets the TLS SNI server name to host_name,
  sends a blank request and then shuts the connection down.
  """

  def __init__(self, ca_cert_path, verify_cb, port, host_name='foo.com',
               host='localhost'):
    """Stores connection parameters; no network activity happens here.

    Args:
      ca_cert_path: certificate file loaded both as the client certificate
        and as the verification location in run_request().
      verify_cb: callback handed to the SSL context's set_verify().
      port: TCP port of the server to connect to.
      host_name: server name passed to set_tlsext_host_name() (SNI).
      host: address to connect to.
    """
    self.ca_cert_path = ca_cert_path
    self.verify_cb = verify_cb
    self.port = port
    self.host_name = host_name
    self.host = host
    self.connection = None  # Set by run_request().

  def run_request(self):
    """Opens an SSL connection, sends a blank request, then closes it.

    The connection is always closed, including when connecting, setting the
    SNI host name, sending, or shutting down raises. Any such exception
    propagates to the caller.
    """
    context = certutils.get_ssl_context()
    # Demand a certificate from the peer; verify_cb sees each chain entry.
    context.set_verify(certutils.VERIFY_PEER, self.verify_cb)
    context.use_certificate_file(self.ca_cert_path)
    context.load_verify_locations(self.ca_cert_path)

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.connection = certutils.get_ssl_connection(context, sock)
    connected = False
    try:
      self.connection.connect((self.host, self.port))
      self.connection.set_tlsext_host_name(self.host_name)
      connected = True
    finally:
      # Without this, a failed connect/SNI setup leaked the socket.
      if not connected:
        self.connection.close()

    try:
      self.connection.send('\r\n\r\n')
    finally:
      # Nested so close() still runs even if shutdown() raises.
      try:
        self.connection.shutdown()
      finally:
        self.connection.close()
58
59
60class Handler(BaseHTTPServer.BaseHTTPRequestHandler):
61  protocol_version = 'HTTP/1.1'  # override BaseHTTPServer setting
62
63  def handle_one_request(self):
64    """Handle a single HTTP request."""
65    self.raw_requestline = self.rfile.readline(65537)
66
67
class WrappedErrorHandler(Handler):
  """Handler that records whether sslproxy's dummy-cert setup failed.

  When sslproxy._SetUpUsingDummyCert raises certutils.Error, the error class
  is stored on the server as error_function instead of being propagated.
  """

  def setup(self):
    """Runs the base setup, then attempts sslproxy's dummy-cert setup."""
    Handler.setup(self)
    try:
      sslproxy._SetUpUsingDummyCert(self)
    except certutils.Error:
      # Record the failure on the server rather than raising it here.
      self.server.error_function = certutils.Error

  def finish(self):
    """Runs the base finish, then shuts down and closes the connection.

    NOTE(review): calls shutdown() with no argument. That presumes
    self.connection is the SSL connection installed during setup; a plain
    socket's shutdown() requires a `how` argument.
    """
    Handler.finish(self)
    conn = self.connection
    conn.shutdown()
    conn.close()
82
83
class DummyArchive(object):
  """Empty placeholder for an HTTP archive; it carries no state."""

  def __init__(self):
    """Creates the empty placeholder; there is nothing to initialize."""
88
89
class DummyFetch(object):
  """Stand-in fetch object whose only attribute is `http_archive`."""

  def __init__(self):
    """Attaches a fresh DummyArchive as `http_archive`."""
    archive = DummyArchive()
    self.http_archive = archive
94
95
class Server(BaseHTTPServer.HTTPServer):
  """SSL test server, meant to be used as a context manager.

  Entering the `with` block runs serve_forever() on a daemon thread. Exiting
  stops that loop and closes the listening socket.
  """

  def __init__(self, ca_cert_path, use_error_handler=False, port=0,
               host='localhost'):
    """Reads the CA file and binds the server.

    Args:
      ca_cert_path: path of the CA file whose contents are used to generate
        per-host certificates in get_certificate().
      use_error_handler: if True, use WrappedErrorHandler; otherwise use
        Handler wrapped by sslproxy.wrap_handler.
      port: port to bind; 0 lets the OS pick one (see server_port).
      host: address to bind.

    Raises:
      RuntimeError: if HTTPServer initialization (e.g. binding) fails.
    """
    self.ca_cert_path = ca_cert_path
    with open(ca_cert_path, 'r') as ca_file:
      self.ca_cert_str = ca_file.read()
    self.http_archive_fetch = DummyFetch()
    if use_error_handler:
      self.HANDLER = WrappedErrorHandler
    else:
      self.HANDLER = sslproxy.wrap_handler(Handler)
    try:
      BaseHTTPServer.HTTPServer.__init__(self, (host, port), self.HANDLER)
    except Exception as e:
      raise RuntimeError('Could not start HTTPSServer on port %d: %s'
                         % (port, e))

  def __enter__(self):
    """Starts serve_forever() on a daemon thread and returns self."""
    thread = threading.Thread(target=self.serve_forever)
    thread.daemon = True  # Never keep the process alive on test exit.
    thread.start()
    return self

  def cleanup(self):
    """Stops the serving loop, then releases the listening socket."""
    try:
      self.shutdown()
    except KeyboardInterrupt:
      pass
    # shutdown() only stops serve_forever(). Without server_close() the
    # listening socket leaked after every `with Server(...)` block.
    self.server_close()

  def __exit__(self, type_, value_, traceback_):
    """Cleans up; returns None so exceptions from the block propagate."""
    self.cleanup()

  def get_certificate(self, host):
    """Returns certutils.generate_cert(ca_cert_str, '', host).

    Presumably called by the sslproxy-wrapped handler to obtain a
    certificate for the requested host; confirm in sslproxy.
    """
    return certutils.generate_cert(self.ca_cert_str, '', host)
132
133
class TestClient(unittest.TestCase):
  """Runs Client requests against a Server backed by dummy CA certs."""
  _temp_dir = None  # Checked by tearDown in case setUp failed early.

  def setUp(self):
    """Creates a temp dir holding a valid CA and an unrelated "wrong" CA."""
    self._temp_dir = tempfile.mkdtemp(prefix='sslproxy_', dir='/tmp')
    # mkdtemp() returns a path with no trailing separator. Plain string
    # concatenation therefore placed these files next to the temp dir, where
    # tearDown never removed them.
    self.ca_cert_path = os.path.join(self._temp_dir, 'testCA.pem')
    self.cert_path = os.path.join(self._temp_dir, 'testCA-cert.cer')
    self.wrong_ca_cert_path = os.path.join(self._temp_dir, 'wrong.pem')
    self.wrong_cert_path = os.path.join(self._temp_dir, 'wrong-cert.cer')

    # Write both pem and cer files for each certificate. The second call
    # previously rewrote ca_cert_path, so the "wrong" cert never existed.
    # Assumes write_dummy_ca_cert derives the '<name>-cert.cer' files used
    # by cert_path / wrong_cert_path; confirm in certutils.
    certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
                                  cert_path=self.ca_cert_path)
    certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
                                  cert_path=self.wrong_ca_cert_path)

  def tearDown(self):
    """Removes the temp dir and every certificate written into it."""
    if self._temp_dir:
      shutil.rmtree(self._temp_dir)

  def verify_cb(self, conn, cert, errnum, depth, ok):
    """A callback that checks the presented certificate's validity window.

    Args:
      conn: Connection object
      cert: x509 object
      errnum: possible error number
      depth: error depth
      ok: 1 if the authentication worked, 0 if it didn't.
    Returns:
      ok, unchanged, so the SSL library's own verdict is kept.
    """
    self.assertFalse(cert.has_expired())
    # Both sides are 'YYYYmmddHHMMSSZ' strings, so string order matches
    # chronological order.
    self.assertGreater(time.strftime('%Y%m%d%H%M%SZ', time.gmtime()),
                       cert.get_notBefore())
    return ok

  def test_no_host(self):
    with Server(self.ca_cert_path) as server:
      c = Client(self.cert_path, self.verify_cb, server.server_port, '')
      self.assertRaises(certutils.Error, c.run_request)

  def test_client_connection(self):
    with Server(self.ca_cert_path) as server:
      c = Client(self.cert_path, self.verify_cb, server.server_port, 'foo.com')
      c.run_request()

      c = Client(self.cert_path, self.verify_cb, server.server_port,
                 'random.host')
      c.run_request()

  def test_wrong_cert(self):
    with Server(self.ca_cert_path, True) as server:
      c = Client(self.wrong_cert_path, self.verify_cb, server.server_port,
                 'foo.com')
      self.assertRaises(certutils.Error, c.run_request)
190
191
if __name__ == '__main__':
  # Restore the default SIGINT action so Ctrl-C terminates the process
  # directly, rather than going through Python's KeyboardInterrupt handling.
  signal.signal(signal.SIGINT, signal.SIG_DFL)  # Exit on Ctrl-C
  unittest.main()
195