1# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
2# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
3
4# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
5# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
6
7"""
8WSGI middleware
9
10Gzip-encodes the response.
11"""
12
13import gzip
14from paste.response import header_value, remove_header
15from paste.httpheaders import CONTENT_LENGTH
16import six
17
class GzipOutput(object):
    """Empty placeholder class; it defines no attributes or methods."""
20
class middleware(object):
    """WSGI middleware that gzip-encodes responses for gzip-capable clients.

    ``compress_level`` is coerced with ``int()`` and handed to
    ``GzipResponse`` for every request that gets compressed.
    """

    def __init__(self, application, compress_level=6):
        self.application = application
        self.compress_level = int(compress_level)

    def __call__(self, environ, start_response):
        """Run the wrapped app, buffering its output when gzip is accepted.

        If ``HTTP_ACCEPT_ENCODING`` does not mention ``gzip`` the wrapped
        application is called directly with the server's ``start_response``
        and its result is returned untouched.  Otherwise the whole response
        is collected by a ``GzipResponse`` and returned as a one-item list.
        """
        accepted = environ.get('HTTP_ACCEPT_ENCODING', '')
        if 'gzip' in accepted:
            gzip_response = GzipResponse(start_response, self.compress_level)
            result = self.application(
                environ, gzip_response.gzip_start_response)
            if result is not None:
                gzip_response.finish_response(result)
            return gzip_response.write()
        # Client did not advertise gzip support: pass straight through.
        return self.application(environ, start_response)
39
class GzipResponse(object):
    """Buffer a WSGI response and, when appropriate, gzip-encode its body.

    ``gzip_start_response`` is given to the wrapped application in place of
    the server's ``start_response``: it records the status and headers and
    decides whether the body is compressible.  ``finish_response`` then
    drains the application's iterable into ``self.buffer`` (through a
    ``GzipFile`` when compressing), sets ``Content-Length`` and calls the
    real ``start_response``.  ``write`` returns the buffered body.
    """

    def __init__(self, start_response, compress_level):
        self.start_response = start_response
        self.compress_level = compress_level
        self.buffer = six.BytesIO()
        self.compressible = False
        self.content_length = None
        # Chunks passed to the WSGI write() callable before
        # finish_response() has chosen the output stream.
        self._pending = []
        # Stream that body chunks go to once finish_response() starts
        # (self.buffer, or a GzipFile wrapping it); None until then.
        self._output = None

    def gzip_start_response(self, status, headers, exc_info=None):
        """Record ``status``/``headers`` and decide whether to compress.

        The body is compressed when the Content-Type starts with ``text/``
        or ``application/``, does not contain ``zip``, and no
        Content-Encoding is already present; in that case a
        ``content-encoding: gzip`` header is appended to ``headers``.
        Content-Length is always removed because ``finish_response``
        recomputes it.  ``exc_info`` is accepted for WSGI compatibility but
        unused: nothing is sent upstream until ``finish_response``.

        Returns the WSGI ``write`` callable.
        """
        ct = header_value(headers, 'content-type')
        ce = header_value(headers, 'content-encoding')
        self.compressible = bool(
            ct
            and ct.startswith(('text/', 'application/'))
            and 'zip' not in ct
            and not ce)
        if self.compressible:
            headers.append(('content-encoding', 'gzip'))
        remove_header(headers, 'content-length')
        self.headers = headers
        self.status = status
        # Hand out _write rather than self.buffer.write so that data sent
        # through write() ends up in the same (possibly gzip) stream as the
        # iterable body instead of being stored raw ahead of it.
        return self._write

    def _write(self, data):
        """WSGI ``write`` callable returned by ``gzip_start_response``."""
        if self._output is None:
            # Output stream not chosen yet; finish_response flushes these.
            self._pending.append(data)
        else:
            self._output.write(data)

    def write(self):
        """Return the buffered body as a one-item list and close the buffer."""
        body = self.buffer.getvalue()
        self.buffer.close()
        return [body]

    def finish_response(self, app_iter):
        """Drain ``app_iter`` into the buffer and call the real start_response.

        Chunks given to the ``write`` callable earlier are emitted first, in
        order, followed by the chunks of ``app_iter``.  ``app_iter.close()``
        is called if it exists, even when iteration raises.  Afterwards
        ``Content-Length`` is set to the buffered size and the upstream
        ``start_response`` is called with the recorded status and headers.
        """
        if self.compressible:
            output = gzip.GzipFile(mode='wb', compresslevel=self.compress_level,
                                   fileobj=self.buffer)
        else:
            output = self.buffer
        self._output = output
        try:
            pending, self._pending = self._pending, []
            for chunk in pending:
                output.write(chunk)
            for chunk in app_iter:
                output.write(chunk)
            if self.compressible:
                # Flushes the gzip trailer into self.buffer; GzipFile.close()
                # does not close the underlying fileobj.
                output.close()
        finally:
            if hasattr(app_iter, 'close'):
                app_iter.close()
        self.content_length = self.buffer.tell()
        CONTENT_LENGTH.update(self.headers, self.content_length)
        self.start_response(self.status, self.headers)
90
def filter_factory(application, **conf):
    """Deprecated: return a function that wraps an app in ``middleware``.

    Emits a ``DeprecationWarning`` pointing callers at
    ``make_gzip_middleware``.  The first positional argument is not used;
    presumably it is the Paste Deploy ``global_conf`` (TODO confirm).
    A ``compress_level`` key in ``conf`` is forwarded to ``middleware``
    (default 6); previously all of ``conf`` was silently ignored.
    """
    import warnings
    warnings.warn(
        'This function is deprecated; use make_gzip_middleware instead',
        DeprecationWarning, 2)
    compress_level = conf.get('compress_level', 6)

    def _filter(application):
        return middleware(application, compress_level=compress_level)
    return _filter
99
def make_gzip_middleware(app, global_conf, compress_level=6):
    """
    Build the gzip middleware around ``app``.

    Responses are gzip-encoded when the client's Accept-Encoding mentions
    gzip and the content type is ``text/*`` or ``application/*``.
    ``global_conf`` is accepted but unused; ``compress_level`` is converted
    with ``int()`` before the middleware is created.
    """
    level = int(compress_level)
    return middleware(app, compress_level=level)
108