1# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Spins up a trivial HTTP cgi form listener in a thread.
6
7   This HTTPThread class is a utility for use with test cases that
8   need to call back to the Autotest test case with some form value, e.g.
9   http://localhost:nnnn/?status="Browser started!"
10"""
11
12import cgi, errno, logging, os, posixpath, SimpleHTTPServer, socket, ssl, sys
13import threading, urllib, urlparse
14from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
15from SocketServer import BaseServer, ThreadingMixIn
16
17
def _handle_http_errors(func):
    """Decorator function for cleaner presentation of certain exceptions.

    Wraps a request-handler method so that an IOError whose errno is EPIPE
    or ECONNRESET is reported through self.log_error() as a single line
    instead of propagating.  Every other exception is re-raised unchanged.

    func is called as func(self, *args, **kwargs).  The returned wrapper
    passes back whatever func returns, or None when one of the suppressed
    errors was logged.
    """
    # Imported here because the module-level imports do not include it.
    import functools

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except IOError as e:
            if e.errno in (errno.EPIPE, errno.ECONNRESET):
                # Instead of dumping a stack trace, a single line is
                # sufficient.  The error is passed as an argument rather than
                # as the format string so a '%' in its text cannot break
                # log_error's own formatting.
                self.log_error('%s', e)
            else:
                raise

    return wrapper
31
32
class FormHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    """Request handler that records posted form values and serves files.

    POST requests have their key=value form fields stored on the server
    object (server._form_entries) and, unless a URL handler is registered
    for the path, echoed back in the response.  If the form submission is a
    file upload, the file will be written to disk with the name contained in
    the 'filename' field.

    GET requests are passed to a registered URL handler when one exists for
    the path; otherwise they are served by SimpleHTTPRequestHandler from
    server.docroot.

    After a GET or POST has been handled, the event registered in
    server._wait_urls for the request path (if any) is set and removed.
    """

    # Note: this updates the map on the base class, which this class inherits.
    SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({
        '.webm': 'video/webm',
    })

    # Override the default logging methods to use the logging module directly.
    def log_error(self, format, *args):
        """Log an error line for this request at WARNING level."""
        logging.warning("(httpd error) %s - - [%s] %s\n",
                        self.address_string(), self.log_date_time_string(),
                        format % args)

    def log_message(self, format, *args):
        """Log an access line for this request at DEBUG level."""
        logging.debug("%s - - [%s] %s\n",
                      self.address_string(), self.log_date_time_string(),
                      format % args)

    @_handle_http_errors
    def do_POST(self):
        """Handle a POST: record the form fields, respond, fire the event.

        The parsed fields are stored in server._form_entries.  The response
        is produced by the URL handler registered for the path, called as
        handler(self, form), or by write_post_response() when there is none.
        """
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            # .get() so that a request without a Content-Type header does not
            # fail with a KeyError before it is even parsed.
            environ={'REQUEST_METHOD': 'POST',
                     'CONTENT_TYPE': self.headers.get('Content-Type', '')})
        # You'd think form.keys() would just return [], like it does for empty
        # python dicts; you'd be wrong. It raises TypeError if called when it
        # has no keys.
        if form:
            for field in form.keys():
                # getvalue() yields the single value, or a list of values
                # when the field was submitted more than once.
                self.server._form_entries[field] = form.getvalue(field)
        path = urlparse.urlparse(self.path)[2]
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, form)
        else:
            # Echo back information about what was posted in the form.
            self.write_post_response(form)
        # NOTE(review): unlike do_GET, self.path still carries any query
        # string here, so the event lookup uses the full request path.
        self._fire_event()

    def write_post_response(self, form):
        """Called to fill out the response to an HTTP POST.

        Override this method to give custom responses.

        form is the cgi.FieldStorage parsed from the request.  Each field is
        echoed back; uploaded files are additionally written to disk under
        the file name supplied in the request.
        """
        # Send response boilerplate
        self.send_response(200)
        self.end_headers()
        self.wfile.write('Hello from Autotest!\nClient: %s\n' %
                         str(self.client_address))
        self.wfile.write('Request for path: %s\n' % self.path)
        self.wfile.write('Got form data:\n')

        # form.keys() raises TypeError when the form has no keys, hence the
        # truthiness check first.
        if not form:
            return
        for field in form.keys():
            items = form[field]
            # A field submitted more than once comes back as a list.
            if not isinstance(items, list):
                items = [items]
            for field_item in items:
                if field_item.filename:
                    # The field contains an uploaded file
                    upload = field_item.file.read()
                    self.wfile.write('\tUploaded %s (%d bytes)<br>' %
                                     (field, len(upload)))
                    # Write submitted file to specified filename.
                    # NOTE(review): the name comes from the client and is
                    # used as-is, so a request can write to any path this
                    # process can write; only suitable for trusted clients.
                    with open(field_item.filename, 'wb') as out:
                        out.write(upload)
                else:
                    self.wfile.write('\t%s=%s<br>' %
                                     (field, field_item.value))

    def translate_path(self, path):
        """Override SimpleHTTPRequestHandler's translate_path to serve
        from arbitrary docroot

        Returns the filesystem path made of server.docroot joined with the
        components of the request path.  Query parameters are dropped, and
        '.' / '..' components are skipped so the result stays below docroot.
        """
        # abandon query parameters
        path = urlparse.urlparse(path)[2]
        path = posixpath.normpath(urllib.unquote(path))
        words = [word for word in path.split('/') if word]
        path = self.server.docroot
        for word in words:
            # Keep only the final component, without any drive prefix.
            word = os.path.splitdrive(word)[1]
            word = os.path.split(word)[1]
            if word in (os.curdir, os.pardir):
                continue
            path = os.path.join(path, word)
        logging.debug('Translated path: %s', path)
        return path

    def _fire_event(self):
        """Set and unregister the event waiting on self.path, if any."""
        # A single pop() both looks up and removes the entry, so two handler
        # threads serving the same URL cannot trip over each other between a
        # membership test and a delete.
        entry = self.server._wait_urls.pop(self.path, None)
        if entry is None:
            logging.debug('URL %s not in watch list', self.path)
            return
        _, event = entry
        event.set()

    @_handle_http_errors
    def do_GET(self):
        """Handle a GET: dispatch to a URL handler or serve a file.

        A registered handler is called as handler(self, args), where args is
        the parsed query string (a dict mapping names to lists of values).
        """
        split_url = urlparse.urlsplit(self.path)
        path = split_url[2]
        # Strip off query parameters to ensure that the url path
        # matches any registered events.
        self.path = path
        args = urlparse.parse_qs(split_url[3])
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, args)
        else:
            SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
        self._fire_event()

    @_handle_http_errors
    def do_HEAD(self):
        """Handle a HEAD request using the SimpleHTTPRequestHandler logic."""
        SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self)
159
160
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer combined with ThreadingMixIn."""

    def __init__(self, server_address, HandlerClass):
        """Set up the server via HTTPServer's constructor.

        server_address and HandlerClass are forwarded unchanged.
        """
        base_init = HTTPServer.__init__
        base_init(self, server_address, HandlerClass)
164
165
class HTTPListener(object):
    """Owns a ThreadedHTTPServer using FormHandler and a thread to run it.

    The server is created (and its port bound) in the constructor; run()
    starts serving in a background thread and stop() shuts it down.
    """

    # Point default docroot to a non-existent directory (instead of None) to
    # avoid exceptions when page content is served through handlers only.
    def __init__(self, port=0, docroot='/_', wait_urls=None,
                 url_handlers=None):
        """Create the server on the given port.

        wait_urls and url_handlers are the initial dictionaries used by the
        server; a fresh empty dictionary is used for each one left as None,
        so separate listeners never share registrations.
        """
        self._server = ThreadedHTTPServer(('', port), FormHandler)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def config_server(self, server, docroot, wait_urls, url_handlers):
        """Attach the listener's data fields to the server object.

        Note that the fields are set on self._server; the server argument is
        accepted but not otherwise used.  A wait_urls or url_handlers of None
        is replaced by a new empty dictionary; a dictionary passed in is kept
        by reference.
        """
        # Stuff some convenient data fields into the server object.
        self._server.docroot = docroot
        self._server._wait_urls = {} if wait_urls is None else wait_urls
        self._server._url_handlers = (
            {} if url_handlers is None else url_handlers)
        self._server._form_entries = {}
        self._server_thread = threading.Thread(
            target=self._server.serve_forever)

    def add_wait_url(self, url='/', matchParams=None):
        """Register url to be watched and return the event tied to it.

        The returned threading.Event is stored together with matchParams
        (an empty dictionary when omitted) under url.
        """
        if matchParams is None:
            matchParams = {}
        e = threading.Event()
        self._server._wait_urls[url] = (matchParams, e)
        return e

    def add_url_handler(self, url, handler_func):
        """Register handler_func as the handler for requests to url."""
        self._server._url_handlers[url] = handler_func

    def clear_form_entries(self):
        """Discard all form values collected so far."""
        self._server._form_entries = {}

    def get_form_entries(self):
        """Returns a dictionary of all field=values received by the server.
        """
        return self._server._form_entries

    def run(self):
        """Start serving requests in the background thread."""
        logging.debug('http server on %s:%d',
                      self._server.server_name, self._server.server_port)
        self._server_thread.start()

    def stop(self):
        """Stop serving and close the listening socket.

        Safe to call when run() was never called: shutdown() and join() are
        only used when the server thread is actually running, since
        shutdown() waits for the serving loop and join() raises on a thread
        that was never started.
        """
        serving = self._server_thread.is_alive()
        if serving:
            self._server.shutdown()
        self._server.socket.close()
        if serving:
            self._server_thread.join()
213
214
class SecureHTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded HTTP server whose listening socket is wrapped with SSL."""

    def __init__(self, server_address, HandlerClass, cert_path, key_path,
                 ssl_version=None):
        """Create, bind and activate the SSL-wrapped server socket.

        cert_path and key_path name the certificate and private key files
        handed to ssl.wrap_socket().  ssl_version selects the protocol; when
        left as None, ssl.PROTOCOL_TLSv1 is used, which is what this class
        has always used.

        If wrapping, binding or activating fails, the socket is closed
        before the exception is re-raised.
        """
        if ssl_version is None:
            ssl_version = ssl.PROTOCOL_TLSv1
        raw_socket = socket.socket(self.address_family, self.socket_type)
        try:
            self.socket = ssl.wrap_socket(raw_socket,
                                          server_side=True,
                                          ssl_version=ssl_version,
                                          certfile=cert_path,
                                          keyfile=key_path)
        except Exception:
            # E.g. unreadable certificate or key: do not leak the socket.
            raw_socket.close()
            raise
        try:
            # BaseServer.__init__ is called instead of HTTPServer.__init__,
            # so binding and activation are done explicitly here on the
            # wrapped socket created above.
            BaseServer.__init__(self, server_address, HandlerClass)
            self.server_bind()
            self.server_activate()
        except Exception:
            self.socket.close()
            raise
226
227
class SecureHTTPRequestHandler(FormHandler):
    """FormHandler variant that SecureHTTPListener installs on its server.

    It differs from FormHandler in how setup() builds the request's read
    and write file objects.
    """

    def setup(self):
        """Build rfile/wfile directly on top of the request socket."""
        # presumably self.request is the SSL-wrapped connection accepted by
        # SecureHTTPServer -- verify against the server's accept path.
        request_socket = self.request
        self.connection = request_socket
        self.rfile = socket._fileobject(request_socket, 'rb', self.rbufsize)
        self.wfile = socket._fileobject(request_socket, 'wb', self.wbufsize)

    # The default logging methods are replaced so that output goes through
    # the logging module.
    def log_error(self, format, *args):
        """Emit an error line for this request at WARNING level."""
        line = "(httpd error) %s - - [%s] %s\n" % (
            self.address_string(),
            self.log_date_time_string(),
            format % args)
        logging.warning(line)

    def log_message(self, format, *args):
        """Emit an access line for this request at DEBUG level."""
        line = "%s - - [%s] %s\n" % (
            self.address_string(),
            self.log_date_time_string(),
            format % args)
        logging.debug(line)
244
245
class SecureHTTPListener(HTTPListener):
    """HTTPListener that serves over SSL through a SecureHTTPServer."""

    def __init__(self,
                 cert_path='/etc/login_trust_root.pem',
                 key_path='/etc/mock_server.key',
                 port=0,
                 docroot='/_',
                 wait_urls=None,
                 url_handlers=None):
        """Create the secure server on the given port.

        cert_path and key_path are passed to SecureHTTPServer.  wait_urls
        and url_handlers are the initial dictionaries used by the server; a
        fresh empty dictionary is used for each one left as None, so
        separate listeners never share registrations.
        """
        # HTTPListener.__init__ is not called: it would create a plain
        # ThreadedHTTPServer.  The server is built here and then configured
        # through the inherited config_server().
        self._server = SecureHTTPServer(('', port),
                                        SecureHTTPRequestHandler,
                                        cert_path,
                                        key_path)
        if wait_urls is None:
            wait_urls = {}
        if url_handlers is None:
            url_handlers = {}
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def getsockname(self):
        """Return getsockname() of the server's listening socket."""
        return self._server.socket.getsockname()
263
264