1# Copyright 2011, Google Inc.
2# All rights reserved.
3#
4# Redistribution and use in source and binary forms, with or without
5# modification, are permitted provided that the following conditions are
6# met:
7#
8#     * Redistributions of source code must retain the above copyright
9# notice, this list of conditions and the following disclaimer.
10#     * Redistributions in binary form must reproduce the above
11# copyright notice, this list of conditions and the following disclaimer
12# in the documentation and/or other materials provided with the
13# distribution.
14#     * Neither the name of Google Inc. nor the names of its
15# contributors may be used to endorse or promote products derived from
16# this software without specific prior written permission.
17#
18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30
31"""WebSocket utilities.
32"""
33
34
35import array
36import errno
37
# Import hash classes from a module available and recommended for each Python
# version and re-export those symbols. Use the sha and md5 modules in Python
# 2.4, and the hashlib module in Python 2.6.
41try:
42    import hashlib
43    md5_hash = hashlib.md5
44    sha1_hash = hashlib.sha1
45except ImportError:
46    import md5
47    import sha
48    md5_hash = md5.md5
49    sha1_hash = sha.sha
50
51import StringIO
52import logging
53import os
54import re
55import socket
56import traceback
57import zlib
58
59
def get_stack_trace():
    """Return the traceback of the exception being handled as a string.

    The traceback is rendered by traceback.print_exc into an in-memory
    buffer whose contents are returned.

    This is needed to support Python 2.3.
    TODO: Remove this when we only support Python 2.4 and above.
          Use traceback.format_exc instead.
    """

    buf = StringIO.StringIO()
    # Arguments are (limit, file); no limit is applied.
    traceback.print_exc(None, buf)
    return buf.getvalue()
71
72
def prepend_message_to_exception(message, exc):
    """Put message in front of the text of exc.

    The exception is modified in place: its args become a one-element
    tuple holding message followed by str(exc).

    Args:
      message: string to put before the exception text
      exc: exception object to update
    """

    combined = message + str(exc)
    exc.args = (combined,)
78
79
def __translate_interp(interp, cygwin_path):
    """Translate interp program path for Win32 python to run cygwin program
    (e.g. perl).  Note that it doesn't support a program path that contains
    a space; such paths are rare on Unix, where #!-scripts are written.
    For Win32 python, cygwin_path is a directory of cygwin binaries.

    Args:
      interp: interp command line
      cygwin_path: directory name of cygwin binary, or None
    Returns:
      translated interp command line. interp is returned unchanged when
      cygwin_path is empty or interp contains no '/'-separated program path.
    """
    if not cygwin_path:
        return interp
    m = re.match(r'^[^ ]*/([^ ]+)( .*)?', interp)
    if not m:
        return interp
    cmd = os.path.join(cygwin_path, m.group(1))
    # The argument group is optional, so it is None for a bare interpreter
    # such as "/usr/bin/perl"; concatenating None would raise TypeError.
    return cmd + (m.group(2) or '')


def get_script_interp(script_path, cygwin_path=None):
    """Gets #!-interpreter command line from the script.

    It also fixes command path.  When Cygwin Python is used, e.g. in WebKit,
    it could run "/usr/bin/perl -wT hello.pl".
    When Win32 Python is used, e.g. in Chromium, it couldn't.  So, the
    directory part of "/usr/bin/perl" is replaced with cygwin_path.

    Args:
      script_path: pathname of the script
      cygwin_path: directory name of cygwin binary, or None
    Returns:
      #!-interpreter command line with surrounding whitespace removed, or
      None if it is not #!-script.
    """
    fp = open(script_path)
    try:
        # Close the file even if reading the first line fails.
        line = fp.readline()
    finally:
        fp.close()
    m = re.match(r'^#!(.*)', line)
    if m:
        # Strip so that "#! /usr/bin/perl" and a trailing "\r" left by CRLF
        # line endings do not defeat the path translation.
        return __translate_interp(m.group(1).strip(), cygwin_path)
    return None
122
123
def wrap_popen3_for_win(cygwin_path):
    """Replace os.popen3 with a version that can launch #!-scripts.

    The replacement looks up the #!-interpreter of the first word of the
    command line and, when there is one, puts the interpreter command line
    in front of the command before calling the original os.popen3.

    Args:
      cygwin_path:  path for cygwin binary if command path is needed to be
                    translated.  None if no translation required.
    """

    original_popen3 = os.popen3

    def popen3_with_interp(cmd, mode='t', bufsize=-1):
        # The script is taken to be everything before the first space.
        script = cmd.split(' ')[0]
        interp = get_script_interp(script, cygwin_path)
        if not interp:
            return original_popen3(cmd, mode, bufsize)
        return original_popen3(interp + ' ' + cmd, mode, bufsize)

    os.popen3 = popen3_with_interp
142
143
def hexify(s):
    """Return the characters of s as space-separated lowercase hex codes.

    Each character is formatted with at least two hex digits, e.g. 'ab'
    becomes '61 62'.
    """

    octets = ['%02x' % ord(ch) for ch in s]
    return ' '.join(octets)
146
147
def get_class_logger(o):
    """Return the logger named "<module>.<class name>" for the class of o."""

    cls = o.__class__
    logger_name = '%s.%s' % (cls.__module__, cls.__name__)
    return logging.getLogger(logger_name)
151
152
class NoopMasker(object):
    """A pass-through masker with the same interface as RepeatedXorMasker.

    It keeps no state; mask hands back exactly the string it was given.
    """

    def mask(self, s):
        """Return s without making any change."""
        return s
163
164
class RepeatedXorMasker(object):
    """A masking object that applies XOR on the string given to mask method
    with the masking bytes given to the constructor repeatedly. This object
    remembers the position in the masking bytes the last mask method call
    ended and resumes from that point on the next mask method call.

    Both the masking bytes and the data may be a Python 2 str or a Python 3
    bytes/bytearray object.
    """

    def __init__(self, mask):
        """Stores the masking bytes as a list of integers.

        Args:
          mask: masking bytes; must not be empty if mask() is going to be
              called with non-empty data.
        """
        octets = []
        for item in mask:
            # Iterating bytes yields integers on Python 3 but one-character
            # strings on Python 2.
            if isinstance(item, int):
                octets.append(item)
            else:
                octets.append(ord(item))
        self._mask = octets
        self._mask_size = len(self._mask)
        self._count = 0

    def mask(self, s):
        """XORs s with the masking bytes, continuing from the position where
        the previous call stopped.

        Args:
          s: data to mask (or unmask)
        Returns:
          masked data of the same length as s.
        """
        result = array.array('B')
        # array.fromstring/tostring were renamed to frombytes/tobytes in
        # Python 3 and later removed; use whichever pair is available.
        if hasattr(result, 'frombytes'):
            result.frombytes(s)
        else:
            result.fromstring(s)
        # Use temporary local variables to eliminate the cost to access
        # attributes
        count = self._count
        mask = self._mask
        mask_size = self._mask_size
        for i, octet in enumerate(result):
            result[i] = octet ^ mask[count]
            count = (count + 1) % mask_size
        self._count = count

        if hasattr(result, 'tobytes'):
            return result.tobytes()
        return result.tostring()
191
192
class DeflateRequest(object):
    """A wrapper class for request object to intercept send and recv to perform
    deflate compression and decompression transparently.

    Only two attributes live on the wrapper itself: _request (the wrapped
    request) and connection (a DeflateConnection around the request's
    connection). Every other attribute read or write is forwarded to the
    wrapped request.
    """

    def __init__(self, request):
        self._request = request
        self.connection = DeflateConnection(request.connection)

    def __getattribute__(self, name):
        if name in ('_request', 'connection'):
            return object.__getattribute__(self, name)
        # Use getattr rather than calling __getattribute__ on the request
        # directly: the builtin goes through the full attribute protocol, so
        # attributes served by the request's __getattr__ hook and requests
        # that have no __getattribute__ method of their own work as well.
        return getattr(object.__getattribute__(self, '_request'), name)

    def __setattr__(self, name, value):
        if name in ('_request', 'connection'):
            return object.__setattr__(self, name, value)
        return setattr(object.__getattribute__(self, '_request'), name, value)
211
212
# By making the wbits option negative, we can suppress the CMF/FLG (2 octets)
# and ADLER32 (4 octets) fields of zlib so that we can use the zlib module
# just as a deflate library. DICTID won't be added as long as we don't set a
# dictionary. An LZ77 window of 32K will be used for both compression and
# decompression. For decompression, we can just use 32K to cover any window
# size. For compression, we use 32K so receivers must use 32K.
#
# Compression level is Z_DEFAULT_COMPRESSION. We don't have to match the
# level to decode.
#
# See zconf.h, deflate.c, inflate.c of the zlib library, and zlibmodule.c of
# Python. See also RFC1950 (ZLIB 3.3).
225
226
class _Deflater(object):
    """Compressor that emits the output of each chunk with a sync flush."""

    def __init__(self, window_bits):
        """Creates the underlying zlib compress object.

        Args:
          window_bits: window size exponent; it is negated before being
              handed to zlib.compressobj.
        """
        self._logger = get_class_logger(self)

        level = zlib.Z_DEFAULT_COMPRESSION
        self._compress = zlib.compressobj(level, zlib.DEFLATED, -window_bits)

    def compress_and_flush(self, bytes):
        """Compresses bytes and returns everything produced up to and
        including a Z_SYNC_FLUSH, so the result can be sent right away.
        """
        body = self._compress.compress(bytes)
        tail = self._compress.flush(zlib.Z_SYNC_FLUSH)
        compressed_bytes = body + tail
        self._logger.debug('Compress input %r', bytes)
        self._logger.debug('Compress result %r', compressed_bytes)
        return compressed_bytes
241
242
class _Inflater(object):
    """Incremental decompressor fed through append().

    Compressed bytes handed to append() are buffered in self._unconsumed and
    decompressed on demand by decompress(). When zlib reports data following
    the end of a stream (unused_data), a fresh zlib decompress object is
    created so that the following stream can be decompressed as well.
    """

    def __init__(self):
        self._logger = get_class_logger(self)

        # Compressed input that has not been handed to / consumed by zlib yet.
        self._unconsumed = ''

        self.reset()

    def decompress(self, size):
        """Decompresses buffered input and returns the resulting data.

        Args:
          size: maximum number of decompressed bytes to return, or -1 to
              decompress everything that has been buffered so far.
        Returns:
          decompressed data. It is empty when the buffered input does not
          produce any output.
        Raises:
          Exception: size is neither -1 nor positive.
        """
        if not (size == -1 or size > 0):
            raise Exception('size must be -1 or positive')

        data = ''

        while True:
            if size == -1:
                # No output limit: zlib is given the whole buffer.
                data += self._decompress.decompress(self._unconsumed)
                # See Python bug http://bugs.python.org/issue12050 to
                # understand why the same code cannot be used for updating
                # self._unconsumed for here and else block.
                self._unconsumed = ''
            else:
                # Limit the output to what is still missing; input that zlib
                # did not process because of the limit is kept for later.
                data += self._decompress.decompress(
                    self._unconsumed, size - len(data))
                self._unconsumed = self._decompress.unconsumed_tail
            if self._decompress.unused_data:
                # Encountered a last block (i.e. a block with BFINAL = 1) and
                # found a new stream (unused_data). We cannot use the same
                # zlib.Decompress object for the new stream. Create a new
                # Decompress object to decompress the new one.
                #
                # It's fine to ignore unconsumed_tail if unused_data is not
                # empty.
                self._unconsumed = self._decompress.unused_data
                self.reset()
                if size >= 0 and len(data) == size:
                    # data is filled. Don't call decompress again.
                    break
                else:
                    # Re-invoke Decompress.decompress to try to decompress all
                    # available bytes before invoking read which blocks until
                    # any new byte is available.
                    continue
            else:
                # Here, since unused_data is empty, even if unconsumed_tail is
                # not empty, bytes of requested length are already in data. We
                # don't have to "continue" here.
                break

        if data:
            self._logger.debug('Decompressed %r', data)
        return data

    def append(self, data):
        """Adds compressed data to the end of the pending input buffer."""
        self._logger.debug('Appended %r', data)
        self._unconsumed += data

    def reset(self):
        """Replaces the zlib decompress object with a new one.

        The new object is created with -zlib.MAX_WBITS as its wbits option.
        """
        self._logger.debug('Reset')
        self._decompress = zlib.decompressobj(-zlib.MAX_WBITS)
304
305
306# Compresses/decompresses given octets using the method introduced in RFC1979.
307
308
class _RFC1979Deflater(object):
    """Compressor that applies DEFLATE to a byte sequence and flushes it the
    way RFC1979 section 2.1 describes.
    """

    def __init__(self, window_bits, no_context_takeover):
        """Remembers the compression parameters.

        Args:
          window_bits: window size exponent handed to the compressor, or
              None to use zlib.MAX_WBITS.
          no_context_takeover: if true, a new compressor is created for
              every call of filter.
        """
        if window_bits is None:
            window_bits = zlib.MAX_WBITS
        self._window_bits = window_bits
        self._no_context_takeover = no_context_takeover
        # Created lazily on the first call of filter.
        self._deflater = None

    def filter(self, bytes):
        """Returns bytes compressed, without the 4 trailing flush octets."""
        if self._deflater is None or self._no_context_takeover:
            self._deflater = _Deflater(self._window_bits)

        flushed = self._deflater.compress_and_flush(bytes)
        # The last 4 octets are the LEN and NLEN fields of the non-compressed
        # block that Z_SYNC_FLUSH adds; they are not sent.
        return flushed[:-4]
328
329
class _RFC1979Inflater(object):
    """Decompressor for byte sequences that were compressed and flushed with
    the algorithm described in the RFC1979 section 2.1.
    """

    def __init__(self):
        self._inflater = _Inflater()

    def filter(self, bytes):
        """Returns everything that bytes decompresses to."""
        # Put back the LEN and NLEN fields of the non-compressed block added
        # for Z_SYNC_FLUSH, which the sender stripped.
        restored = bytes + '\x00\x00\xff\xff'
        inflater = self._inflater
        inflater.append(restored)
        return inflater.decompress(-1)
343
344
class DeflateSocket(object):
    """Socket wrapper that deflates everything passed to send/sendall and
    inflates everything handed out by recv.
    """

    # Number of compressed bytes requested from the socket per recv call.
    _RECV_SIZE = 4096

    def __init__(self, socket):
        self._socket = socket

        self._logger = get_class_logger(self)

        self._deflater = _Deflater(zlib.MAX_WBITS)
        self._inflater = _Inflater()

    def recv(self, size):
        """Returns up to size bytes of decompressed data.

        Compressed data is read from the socket given to the constructor
        until some decompressed data is available; that data is returned
        even if it is shorter than size. An empty string is returned when
        the socket's recv returns no data.

        Raises:
          Exception: size is not positive.
        """

        # TODO(tyoshino): Allow call with size=0. It should block until any
        # decompressed data is available.
        if size <= 0:
            raise Exception('Non-positive size passed')

        data = self._inflater.decompress(size)
        while len(data) == 0:
            read_data = self._socket.recv(DeflateSocket._RECV_SIZE)
            if not read_data:
                return ''
            self._inflater.append(read_data)
            data = self._inflater.decompress(size)
        return data

    def sendall(self, bytes):
        """Compresses bytes and sends all of the result."""
        self.send(bytes)

    def send(self, bytes):
        """Compresses bytes, sends all of the result and returns the number
        of uncompressed bytes given.
        """
        compressed = self._deflater.compress_and_flush(bytes)
        self._socket.sendall(compressed)
        return len(bytes)
387
388
class DeflateConnection(object):
    """Connection wrapper that deflates everything passed to write and
    inflates everything handed out by read.
    """

    def __init__(self, connection):
        self._connection = connection

        self._logger = get_class_logger(self)

        self._deflater = _Deflater(zlib.MAX_WBITS)
        self._inflater = _Inflater()

    def get_remote_addr(self):
        """Returns remote_addr of the wrapped connection."""
        return self._connection.remote_addr
    remote_addr = property(get_remote_addr)

    def put_bytes(self, bytes):
        """Same as write."""
        self.write(bytes)

    def read(self, size=-1):
        """Reads decompressed data from the wrapped connection.

        With a positive size, returns as soon as at least one byte (and at
        most size bytes) is available. With size -1, keeps reading until the
        wrapped connection's read returns no data and returns everything
        decompressed up to then.

        Raises:
          Exception: size is neither -1 nor positive.
        """

        # TODO(tyoshino): Allow call with size=0.
        if size != -1 and not (size > 0):
            raise Exception('size must be -1 or positive')

        data = ''
        while True:
            if size == -1:
                request_size = -1
            else:
                request_size = size - len(data)
            data += self._inflater.decompress(request_size)

            if data and size >= 0:
                return data

            # TODO(tyoshino): Make this read efficient by some workaround.
            #
            # In 3.0.3 and prior of mod_python, read blocks until length bytes
            # was read. We don't know the exact size to read while using
            # deflate, so read byte-by-byte.
            #
            # _StandaloneRequest.read that ultimately performs
            # socket._fileobject.read also blocks until length bytes was read
            read_data = self._connection.read(1)
            if not read_data:
                return data
            self._inflater.append(read_data)

    def write(self, bytes):
        """Compresses bytes and writes the result to the wrapped connection."""
        compressed = self._deflater.compress_and_flush(bytes)
        self._connection.write(compressed)
444
445
def _is_ewouldblock_errno(error_number):
    """Tells whether error_number means that a receive operation would block.

    Not every platform defines every one of the error names, so only the
    names that the errno module actually provides are compared.

    Returns:
      True iff error_number equals one of the available errno values of
      WSAEWOULDBLOCK, EWOULDBLOCK and EAGAIN.
    """

    for error_name in ('WSAEWOULDBLOCK', 'EWOULDBLOCK', 'EAGAIN'):
        if not hasattr(errno, error_name):
            continue
        if error_number == getattr(errno, error_name):
            return True
    return False
457
458
def drain_received_data(raw_socket):
    """Reads and returns the data that can be received without blocking.

    The socket is switched to non-blocking mode while draining and its
    original timeout is restored before returning, also when an error is
    raised.

    Args:
      raw_socket: socket object to read from
    Returns:
      the received data joined into one string; empty if nothing was
      available.
    Raises:
      socket.error: receiving failed for a reason other than that the
          operation would block.
    """
    import sys

    # Set the socket non-blocking.
    original_timeout = raw_socket.gettimeout()
    raw_socket.settimeout(0.0)

    drained_data = []

    try:
        # Drain until the socket is closed or no data is immediately
        # available for read.
        while True:
            try:
                data = raw_socket.recv(1)
            except socket.error:
                # Fetch the exception through sys.exc_info so that this code
                # parses on every Python version.
                e = sys.exc_info()[1]
                # e can carry either a pair (errno, string) or just a string
                # (or something else) telling what went wrong. We suppress
                # only the errors that indicate that the socket blocks.
                # Those exceptions carry a pair (errno, string).
                args = e.args
                if len(args) == 2 and _is_ewouldblock_errno(args[0]):
                    break
                # Re-raise with the original traceback.
                raise
            if not data:
                break
            drained_data.append(data)
    finally:
        # Rollback timeout value.
        raw_socket.settimeout(original_timeout)

    return ''.join(drained_data)
494
495
496# vi:sts=4 sw=4 et
497