1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
import BaseHTTPServer
import os
import ssl
import threading
8
9
class Responder(object):
  """Sends a HTTP response. Used with WebServer and SyncWebServer."""

  def __init__(self, handler):
    # handler: the BaseHTTPRequestHandler serving the current request.
    self._handler = handler

  def SendResponse(self, body):
    """Sends a 200 OK response with body and a matching Content-Length."""
    self.SendHeaders(len(body))
    self.SendBody(body)

  def SendResponseFromFile(self, path):
    """Sends a 200 OK response with the contents of the file at path as body.

    The file is read in binary mode so non-text files are sent byte-for-byte
    and the Content-Length header matches exactly what is written.
    """
    with open(path, 'rb') as f:
      self.SendResponse(f.read())

  def SendHeaders(self, content_length=None):
    """Sends the status line and headers for a 200 OK response.

    Args:
      content_length: size of the body that will follow. The Content-Length
          header is omitted only when this is None; a length of 0 is still
          sent so clients know the body is empty.
    """
    self._handler.send_response(200)
    if content_length is not None:
      self._handler.send_header('Content-Length', content_length)
    self._handler.end_headers()

  def SendError(self, code):
    """Sends response for the given HTTP error code."""
    self._handler.send_error(code)

  def SendBody(self, body):
    """Just sends the body, no headers."""
    self._handler.wfile.write(body)
40
41
class Request(object):
  """Thin read-only wrapper around the request handler's incoming request."""

  def __init__(self, handler):
    self._handler = handler

  def GetPath(self):
    """Returns the handler's raw request path (query string included)."""
    handler = self._handler
    return handler.path

  def GetHeader(self, name):
    """Returns header `name` as reported by the handler's headers.getheader()."""
    headers = self._handler.headers
    return headers.getheader(name)
53
54
class _BaseServer(BaseHTTPServer.HTTPServer):
  """Internal server that throws if timed out waiting for a request."""

  def __init__(self, on_request, server_cert_and_key_path=None):
    """Starts the server.

    It is an HTTP server if parameter server_cert_and_key_path is not provided.
    Otherwise, it is an HTTPS server.

    Args:
      on_request: callable invoked as on_request(Request, Responder) for every
          GET request, except requests for favicon.ico which get a 404.
      server_cert_and_key_path: path to a PEM file containing the cert and key.
          If it is None, start the server as an HTTP one.
    """
    class _Handler(BaseHTTPServer.BaseHTTPRequestHandler):
      """Internal handler that just asks the server to handle the request."""

      def do_GET(self):
        # Answer favicon probes here so on_request never sees them.
        if self.path.endswith('favicon.ico'):
          self.send_error(404)
          return
        on_request(Request(self), Responder(self))

      def log_message(self, *args, **kwargs):
        """Overrides base class method to disable logging."""
        pass

    # Port 0 lets the OS pick a free port; GetUrl() reads it via server_port.
    BaseHTTPServer.HTTPServer.__init__(self, ('127.0.0.1', 0), _Handler)

    self._is_https_enabled = server_cert_and_key_path is not None
    if self._is_https_enabled:
      # This object *is* the HTTPServer, so wrap its own listening socket
      # (there is no separate self._server here).
      self.socket = ssl.wrap_socket(
          self.socket, certfile=server_cert_and_key_path, server_side=True)

  def handle_timeout(self):
    """Overridden from SocketServer."""
    raise RuntimeError('Timed out waiting for http request')

  def GetUrl(self):
    """Returns the base URL of the server."""
    postfix = '://127.0.0.1:%s' % self.server_port
    if self._is_https_enabled:
      return 'https' + postfix
    return 'http' + postfix
101
102
class WebServer(object):
  """An HTTP or HTTPS server that serves on its own thread.

  Serves files from given directory but may use custom data for specific paths.
  """

  def __init__(self, root_dir, server_cert_and_key_path=None):
    """Starts the server.

    It is an HTTP server if parameter server_cert_and_key_path is not provided.
    Otherwise, it is an HTTPS server.

    Args:
      root_dir: root path to serve files from. This parameter is required.
      server_cert_and_key_path: path to a PEM file containing the cert and key.
                                if it is None, start the server as an HTTP one.
    """
    self._root_dir = os.path.abspath(root_dir)
    # Initialize the path maps before the serving thread starts so a request
    # that arrives immediately never sees a half-constructed object.
    self._path_data_map = {}
    self._path_callback_map = {}
    self._path_maps_lock = threading.Lock()
    self._server = _BaseServer(self._OnRequest, server_cert_and_key_path)
    self._thread = threading.Thread(target=self._server.serve_forever)
    self._thread.daemon = True
    self._thread.start()

  def _OnRequest(self, request, responder):
    """Serves one request: callback map first, then data map, then a file."""
    path = request.GetPath().split('?')[0]

    # Only look up the maps under the lock. The callback and the response are
    # run outside it so a callback may call SetDataForPath/SetCallbackForPath
    # without deadlocking on the non-reentrant lock.
    with self._path_maps_lock:
      has_callback = path in self._path_callback_map
      callback = self._path_callback_map.get(path)
      has_data = path in self._path_data_map
      data = self._path_data_map.get(path)

    if has_callback:
      body = callback(request)
      # Any falsy body (None or empty) from the callback is reported as 503.
      if body:
        responder.SendResponse(body)
      else:
        responder.SendError(503)
      return

    if has_data:
      responder.SendResponse(data)
      return

    self._ServeFile(path, responder)

  def _IsUnderRoot(self, path):
    """Returns True if the normalized path is root_dir itself or inside it."""
    root = self._root_dir
    # Compare against root + separator so '/srv/foo' does not admit
    # '/srv/foobar/...'.
    return path == root or path.startswith(root.rstrip(os.sep) + os.sep)

  def _ServeFile(self, url_path, responder):
    """Serves url_path from root_dir; 403 if it escapes, 404 if not a file."""
    path = os.path.normpath(
        os.path.join(self._root_dir, *url_path.split('/')))
    if not self._IsUnderRoot(path):
      responder.SendError(403)
      return
    # isfile (not exists) so directories yield 404 instead of failing in open().
    if not os.path.isfile(path):
      responder.SendError(404)
      return
    responder.SendResponseFromFile(path)

  def SetDataForPath(self, path, data):
    """Serves data as the body for requests to path (query string ignored)."""
    with self._path_maps_lock:
      self._path_data_map[path] = data

  def SetCallbackForPath(self, path, func):
    """Serves func(request)'s return value for path; falsy results get 503."""
    with self._path_maps_lock:
      self._path_callback_map[path] = func

  def GetUrl(self):
    """Returns the base URL of the server."""
    return self._server.GetUrl()

  def Shutdown(self):
    """Shuts down the server synchronously and releases its socket."""
    self._server.shutdown()
    self._thread.join()
    self._server.server_close()
183
184
class SyncWebServer(object):
  """WebServer for testing.

  Incoming requests are blocked until explicitly handled.
  This was designed for single thread use. All requests should be handled on
  the same thread.
  """

  def __init__(self, timeout=10):
    """Creates the server; no request is served until Respond() is called.

    Args:
      timeout: seconds each wait inside Respond() may block before the
          server's timeout handling raises RuntimeError. Defaults to 10.
    """
    self._server = _BaseServer(self._OnRequest)
    # Recognized by SocketServer.
    self._server.timeout = timeout
    self._on_request = None

  def _OnRequest(self, request, responder):
    # Hand the request to the pending handler, then mark it as consumed so
    # the loop in Respond() stops.
    self._on_request(responder)
    self._on_request = None

  def Respond(self, on_request):
    """Blocks until request comes in, then calls given handler function.

    Args:
      on_request: Function that handles the request. Invoked with single
          parameter, an instance of Responder.

    Raises:
      RuntimeError: if another Respond() is still pending, or if waiting for
          the request fails (e.g. the server times out).
    """
    if self._on_request:
      raise RuntimeError('Must handle 1 request at a time.')

    self._on_request = on_request
    try:
      while self._on_request:
        # Don't use handle_one_request, because it won't work with the timeout.
        self._server.handle_request()
    finally:
      # Always clear the pending handler. Otherwise a timeout would leave it
      # set and every later Respond() would fail with 'Must handle 1 request'.
      self._on_request = None

  def RespondWithContent(self, content):
    """Blocks until request comes in, then handles it with the given content."""
    def SendContent(responder):
      responder.SendResponse(content)
    self.Respond(SendContent)

  def GetUrl(self):
    """Returns the base URL of the server."""
    return self._server.GetUrl()
226