1# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
2# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
3"""
4An application that proxies WSGI requests to a remote server.
5
6TODO:
7
8* Send ``Via`` header?  It's not clear to me this is a Via in the
9  style of a typical proxy.
10
11* Other headers or metadata?  I put in X-Forwarded-For, but that's it.
12
13* Signed data of non-HTTP keys?  This would be for things like
14  REMOTE_USER.
15
16* Something to indicate what the original URL was?  The original host,
17  scheme, and base path.
18
19* Rewriting ``Location`` headers?  mod_proxy does this.
20
21* Rewriting body?  (Probably not on this one -- that can be done with
22  a different middleware that wraps this middleware)
23
24* Example::
25
26    use = egg:Paste#proxy
27    address = http://server3:8680/exist/rest/db/orgs/sch/config/
28    allowed_request_methods = GET
29
30"""
31
32from six.moves import http_client as httplib
33from six.moves.urllib import parse as urlparse
34from six.moves.urllib.parse import quote
35import six
36
37from paste import httpexceptions
38from paste.util.converters import aslist
39
# Response headers that are never relayed back to the WSGI client.  These
# are the HTTP/1.1 hop-by-hop headers (RFC 2616, section 13.5.1); they are
# listed in lower case because parse_headers compares them against
# ``header.lower()``.
filtered_headers = tuple(
    'transfer-encoding connection keep-alive proxy-authenticate '
    'proxy-authorization te trailers upgrade'.split())
52
class Proxy(object):
    """
    WSGI application that forwards each request to the fixed upstream
    ``address`` and relays the upstream response back to the client.

    The request's (URL-quoted) ``PATH_INFO`` is resolved against the path
    of ``address``, ``QUERY_STRING`` is appended, and the ``Host`` header
    is replaced with the upstream host.  The response body is read fully
    into memory and returned as a single chunk.
    """

    def __init__(self, address, allowed_request_methods=(),
                 suppress_http_headers=()):
        """
        ``address``
            upstream URL such as ``http://server:8080/base/``.  Only the
            ``http`` and ``https`` schemes are accepted; any other scheme
            raises ``ValueError`` when a request is handled.

        ``allowed_request_methods``
            request methods to accept (compared case-insensitively).  When
            empty, every method is forwarded; otherwise any other method
            is answered with ``HTTPBadRequest``.

        ``suppress_http_headers``
            request header names (compared case-insensitively, e.g.
            ``cookie``) that are not forwarded upstream.
        """
        self.address = address
        self.parsed = urlparse.urlsplit(address)
        self.scheme = self.parsed[0].lower()
        self.host = self.parsed[1]
        self.path = self.parsed[2]
        self.allowed_request_methods = [
            x.lower() for x in allowed_request_methods if x]

        self.suppress_http_headers = [
            x.lower() for x in suppress_http_headers if x]

    def __call__(self, environ, start_response):
        """
        Forward the request described by ``environ`` to the upstream server.

        Returns a one-element list holding the upstream response body.  The
        upstream connection is closed even when forwarding fails.
        """
        if (self.allowed_request_methods and
                environ['REQUEST_METHOD'].lower()
                not in self.allowed_request_methods):
            return httpexceptions.HTTPBadRequest("Disallowed")(
                environ, start_response)

        # Validate the scheme before touching the request body.
        conn_class = self._connection_class()
        conn = conn_class(self.host)
        try:
            headers = self._request_headers(environ)
            request_body = self._request_body(environ, headers)
            conn.request(environ['REQUEST_METHOD'],
                         self._request_path(environ),
                         request_body, headers)
            res = conn.getresponse()
            headers_out = parse_headers(res.msg)

            status = '%s %s' % (res.status, res.reason)
            start_response(status, headers_out)
            # Honour Content-Length when present; otherwise read to the end
            # of the response.
            length = res.getheader('content-length')
            if length is not None:
                response_body = res.read(int(length))
            else:
                response_body = res.read()
        finally:
            # Always release the socket, even if the upstream request or
            # start_response raised.
            conn.close()
        return [response_body]

    def _connection_class(self):
        """Return the ``httplib`` connection class matching ``self.scheme``."""
        if self.scheme == 'http':
            return httplib.HTTPConnection
        if self.scheme == 'https':
            return httplib.HTTPSConnection
        raise ValueError(
            "Unknown scheme for %r: %r" % (self.address, self.scheme))

    def _request_headers(self, environ):
        """
        Build the upstream request headers.

        Every ``HTTP_*`` key of ``environ`` becomes a lower-case header
        (``HTTP_X_FOO`` -> ``x-foo``), except the client's ``Host`` and the
        names in ``self.suppress_http_headers``.  ``host`` is set to the
        upstream host, and ``x-forwarded-for`` / ``content-type`` are added
        from ``REMOTE_ADDR`` / ``CONTENT_TYPE`` when present.
        """
        headers = {}
        for key, value in environ.items():
            if not key.startswith('HTTP_'):
                continue
            name = key[5:].lower().replace('_', '-')
            if name == 'host' or name in self.suppress_http_headers:
                continue
            headers[name] = value
        headers['host'] = self.host
        if 'REMOTE_ADDR' in environ:
            # Replaces any X-Forwarded-For header sent by the client.
            headers['x-forwarded-for'] = environ['REMOTE_ADDR']
        if environ.get('CONTENT_TYPE'):
            headers['content-type'] = environ['CONTENT_TYPE']
        return headers

    @staticmethod
    def _request_body(environ, headers):
        """
        Read the request body from ``wsgi.input``.

        Stores the body length in ``headers['content-length']``.  Returns
        ``b''`` and leaves ``headers`` untouched when ``CONTENT_LENGTH`` is
        missing or empty.
        """
        content_length = environ.get('CONTENT_LENGTH')
        if not content_length:
            return b''
        if content_length == '-1':
            # Length undetermined: read everything and send the real size.
            body = environ['wsgi.input'].read(-1)
            headers['content-length'] = str(len(body))
        else:
            headers['content-length'] = content_length
            body = environ['wsgi.input'].read(int(content_length))
        return body

    def _request_path(self, environ):
        """
        Return the upstream request path.

        The path is the quoted ``PATH_INFO`` resolved against the path of
        ``address``, plus ``?QUERY_STRING`` when the query is non-empty.
        ``PATH_INFO`` may be absent: PEP 3333 lets servers omit it when it
        is empty.
        """
        path_info = quote(environ.get('PATH_INFO', ''))
        if self.path:
            # Drop one leading '/' so urljoin resolves the request path
            # relative to self.path instead of replacing it.
            if path_info.startswith('/'):
                path_info = path_info[1:]
            path = urlparse.urljoin(self.path, path_info)
        else:
            path = path_info
        if environ.get('QUERY_STRING'):
            path += '?' + environ['QUERY_STRING']
        return path
133
def make_proxy(global_conf, address, allowed_request_methods="",
               suppress_http_headers=""):
    """
    Build a :class:`Proxy` WSGI application from string configuration
    values (``global_conf`` is accepted but not used):

    ``address``
        the full URL of the target server, ending with a trailing ``/``

    ``allowed_request_methods``
        a space-separated list of request methods (e.g. ``GET POST``)

    ``suppress_http_headers``
        a space-separated list of HTTP headers (lower case, without the
        leading ``http_``) that should not be passed on to the target host
    """
    return Proxy(
        address,
        allowed_request_methods=aslist(allowed_request_methods),
        suppress_http_headers=aslist(suppress_http_headers),
    )
156
157
class TransparentProxy(object):

    """
    A proxy that sends the request just as it was given, including
    respecting HTTP_HOST, wsgi.url_scheme, etc.

    This is a way of translating WSGI requests directly to real HTTP
    requests.  All information goes in the environment; modify it to
    modify the way the request is made.

    If you specify ``force_host`` (and optionally ``force_scheme``)
    then HTTP_HOST won't be used to determine where to connect to;
    instead a specific host will be connected to, but the ``Host``
    header in the request will remain intact.
    """

    def __init__(self, force_host=None,
                 force_scheme='http'):
        self.force_host = force_host
        self.force_scheme = force_scheme

    def __repr__(self):
        return '<%s %s force_host=%r force_scheme=%r>' % (
            self.__class__.__name__,
            hex(id(self)),
            self.force_host, self.force_scheme)

    def __call__(self, environ, start_response):
        """
        Replay the request described by ``environ`` over HTTP(S).

        Connects to ``HTTP_HOST`` using ``wsgi.url_scheme``, or to
        ``force_host`` using ``force_scheme`` when ``force_host`` is set.
        Raises ``ValueError`` when that scheme is neither ``http`` nor
        ``https``, or when ``HTTP_HOST`` is missing.  Returns a one-element
        list with the upstream response body.  The connection is always
        closed.
        """
        scheme = environ['wsgi.url_scheme']
        if self.force_host is None:
            conn_scheme = scheme
        else:
            conn_scheme = self.force_scheme
        if conn_scheme == 'http':
            conn_class = httplib.HTTPConnection
        elif conn_scheme == 'https':
            conn_class = httplib.HTTPSConnection
        else:
            # Report the scheme actually used for the connection.  With
            # force_host set, that is force_scheme, not the request scheme.
            raise ValueError(
                "Unknown scheme %r" % conn_scheme)
        if 'HTTP_HOST' not in environ:
            raise ValueError(
                "WSGI environ must contain an HTTP_HOST key")
        if self.force_host is None:
            conn_host = environ['HTTP_HOST']
        else:
            conn_host = self.force_host
        conn = conn_class(conn_host)
        try:
            headers = self._request_headers(environ)
            request_body = self._request_body(environ)
            conn.request(environ['REQUEST_METHOD'],
                         self._request_path(environ),
                         request_body, headers)
            res = conn.getresponse()
            headers_out = parse_headers(res.msg)

            status = '%s %s' % (res.status, res.reason)
            start_response(status, headers_out)
            # Honour Content-Length when present; otherwise read to the end
            # of the response.
            length = res.getheader('content-length')
            if length is not None:
                response_body = res.read(int(length))
            else:
                response_body = res.read()
        finally:
            # Always release the socket, even if the upstream request or
            # start_response raised.
            conn.close()
        return [response_body]

    @staticmethod
    def _request_headers(environ):
        """
        Copy every ``HTTP_*`` key of ``environ`` into a header dict
        (``HTTP_X_FOO`` -> ``x-foo``), keeping the client's ``Host``.

        ``x-forwarded-for`` is added from ``REMOTE_ADDR`` only when the
        client did not send one, and ``content-type`` is added from
        ``CONTENT_TYPE`` when present.
        """
        headers = {
            key[5:].lower().replace('_', '-'): value
            for key, value in environ.items()
            if key.startswith('HTTP_')
        }
        headers['host'] = environ['HTTP_HOST']
        if 'REMOTE_ADDR' in environ and 'HTTP_X_FORWARDED_FOR' not in environ:
            headers['x-forwarded-for'] = environ['REMOTE_ADDR']
        if environ.get('CONTENT_TYPE'):
            headers['content-type'] = environ['CONTENT_TYPE']
        return headers

    @staticmethod
    def _request_body(environ):
        """
        Read the request body from ``wsgi.input``.

        Returns ``b''`` when ``CONTENT_LENGTH`` is missing or empty.  No
        Content-Length header is set here; ``httplib`` derives one from the
        body it is given.
        """
        if not environ.get('CONTENT_LENGTH'):
            return b''
        length = int(environ['CONTENT_LENGTH'])
        body = environ['wsgi.input'].read(length)
        if length == -1:
            # Undetermined length: record the size that was actually read.
            environ['CONTENT_LENGTH'] = str(len(body))
        return body

    @staticmethod
    def _request_path(environ):
        """
        Return the quoted ``SCRIPT_NAME + PATH_INFO``, plus
        ``?QUERY_STRING`` when the query is non-empty.  An empty
        ``QUERY_STRING`` no longer produces a trailing ``'?'``.
        """
        path = quote(environ.get('SCRIPT_NAME', '')
                     + environ.get('PATH_INFO', ''))
        if environ.get('QUERY_STRING'):
            path += '?' + environ['QUERY_STRING']
        return path
249
def parse_headers(message):
    """
    Turn a Message object into a list of WSGI-style headers.

    ``message`` is the ``msg`` attribute of an ``httplib`` response.  The
    result is a list of ``(name, value)`` tuples in message order, without
    the headers named in ``filtered_headers``.  Folded (multi-line) header
    values are joined into a single line with one space on both Python 2
    and Python 3, because WSGI header values must not contain CR or LF.

    Raises ``ValueError`` (Python 2 branch only) for a header line without
    a ``:`` or for a continuation line with no preceding header.
    """
    headers_out = []
    if six.PY3:
        for header, value in message.items():
            if header.lower() in filtered_headers:
                continue
            if isinstance(value, str) and ('\r' in value or '\n' in value):
                # http.client keeps obs-fold line breaks inside the value;
                # unfold them the same way the Python 2 branch below does.
                lines = value.replace('\r\n', '\n').replace('\r', '\n')
                value = ' '.join(
                    line.strip() for line in lines.split('\n')
                    if line.strip())
            headers_out.append((header, value))
    else:
        for full_header in message.headers:
            if not full_header:
                # Shouldn't happen, but we'll just ignore
                continue
            if full_header[0].isspace():
                # Continuation line, add to the last header
                if not headers_out:
                    raise ValueError(
                        "First header starts with a space (%r)" % full_header)
                last_header, last_value = headers_out.pop()
                value = last_value + ' ' + full_header.strip()
                headers_out.append((last_header, value))
                continue
            try:
                header, value = full_header.split(':', 1)
            except ValueError:
                # No ':' in the line, so the 2-tuple unpack failed.
                raise ValueError("Invalid header: %r" % full_header)
            value = value.strip()
            if header.lower() not in filtered_headers:
                headers_out.append((header, value))
    return headers_out
281
def make_transparent_proxy(global_conf, force_host=None,
                           force_scheme='http'):
    """
    Build a :class:`TransparentProxy` from configuration values.

    When ``force_host`` is given, every request is sent to that host over
    ``force_scheme`` while the client's ``Host`` header is forwarded
    unchanged; otherwise the request's own ``HTTP_HOST`` and
    ``wsgi.url_scheme`` decide where to connect.  ``global_conf`` is
    accepted but not used.
    """
    proxy = TransparentProxy(force_host=force_host,
                             force_scheme=force_scheme)
    return proxy
290