1#!/usr/bin/env python
2#
3# Copyright 2012, Google Inc.
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11# notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13# copyright notice, this list of conditions and the following disclaimer
14# in the documentation and/or other materials provided with the
15# distribution.
16#     * Neither the name of Google Inc. nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32
33"""Tests for msgutil module."""
34
35
36import array
37import Queue
38import struct
39import unittest
40import zlib
41
42import set_sys_path  # Update sys.path to locate mod_pywebsocket module.
43
44from mod_pywebsocket import common
45from mod_pywebsocket.extensions import DeflateFrameExtensionProcessor
46from mod_pywebsocket.extensions import PerFrameCompressionExtensionProcessor
47from mod_pywebsocket.extensions import PerMessageCompressionExtensionProcessor
48from mod_pywebsocket import msgutil
49from mod_pywebsocket.stream import InvalidUTF8Exception
50from mod_pywebsocket.stream import Stream
51from mod_pywebsocket.stream import StreamHixie75
52from mod_pywebsocket.stream import StreamOptions
53from mod_pywebsocket import util
54from test import mock
55
56
57# We use one fixed nonce for testing instead of cryptographically secure PRNG.
58_MASKING_NONCE = 'ABCD'
59
60
def _mask_hybi(frame):
    """Masks frame with the fixed test nonce, as a HyBi client would.

    Every byte of frame is XORed with the nonce byte at the same position
    modulo the nonce length, and the nonce is prepended to the result.
    """
    key_octets = [ord(ch) for ch in _MASKING_NONCE]
    key_length = len(key_octets)
    masked = array.array('B')
    masked.fromstring(frame)
    for position in xrange(len(masked)):
        masked[position] ^= key_octets[position % key_length]
    return _MASKING_NONCE + masked.tostring()
71
72
73def _install_extension_processor(processor, request, stream_options):
74    response = processor.get_extension_response()
75    if response is not None:
76        processor.setup_stream_options(stream_options)
77        request.ws_extension_processors.append(processor)
78
79
def _create_request_from_rawdata(
        read_data, deflate_stream=False, deflate_frame_request=None,
        perframe_compression_request=None,
        permessage_compression_request=None):
    """Creates a HyBi MockRequest whose connection yields read_data.

    read_data may be a string or a sequence of strings; it is concatenated
    with ''.join. At most one extension processor is installed: the first of
    deflate_frame_request, perframe_compression_request and
    permessage_compression_request that is not None wins, and the rest are
    ignored.
    """
    connection = mock.MockConn(''.join(read_data))
    request = mock.MockRequest(connection=connection)
    request.ws_version = common.VERSION_HYBI_LATEST

    options = StreamOptions()
    options.deflate_stream = deflate_stream

    request.ws_extension_processors = []
    candidates = (
        (deflate_frame_request, DeflateFrameExtensionProcessor),
        (perframe_compression_request,
         PerFrameCompressionExtensionProcessor),
        (permessage_compression_request,
         PerMessageCompressionExtensionProcessor),
    )
    for extension_request, processor_class in candidates:
        if extension_request is not None:
            _install_extension_processor(
                processor_class(extension_request), request, options)
            break

    request.ws_stream = Stream(request, options)
    return request
102
103
def _create_request(*frames):
    """Creates a MockRequest that reads the given client frames.

    Each element of frames is a (header, body) pair. The body is masked with
    _mask_hybi and appended to its header. The concatenation is what
    request.connection.read() returns for the MockRequest built here.
    """
    masked_frames = [header + _mask_hybi(body) for header, body in frames]
    return _create_request_from_rawdata(masked_frames)
116
117
def _create_blocking_request():
    """Creates a HyBi MockRequest backed by a MockBlockingConn.

    Data written to the returned request can be read back with
    request.connection.written_data().
    """
    request = mock.MockRequest(connection=mock.MockBlockingConn())
    request.ws_version = common.VERSION_HYBI_LATEST
    request.ws_stream = Stream(request, StreamOptions())
    return request
130
131
def _create_request_hixie75(read_data=''):
    """Creates a Hixie-75 MockRequest whose connection yields read_data."""
    connection = mock.MockConn(read_data)
    request = mock.MockRequest(connection=connection)
    request.ws_stream = StreamHixie75(request)
    return request
136
137
def _create_blocking_request_hixie75():
    """Creates a Hixie-75 MockRequest backed by a MockBlockingConn."""
    connection = mock.MockBlockingConn()
    request = mock.MockRequest(connection=connection)
    request.ws_stream = StreamHixie75(request)
    return request
142
143
144class MessageTest(unittest.TestCase):
145    # Tests for Stream
146
147    def test_send_message(self):
148        request = _create_request()
149        msgutil.send_message(request, 'Hello')
150        self.assertEqual('\x81\x05Hello', request.connection.written_data())
151
152        payload = 'a' * 125
153        request = _create_request()
154        msgutil.send_message(request, payload)
155        self.assertEqual('\x81\x7d' + payload,
156                         request.connection.written_data())
157
158    def test_send_medium_message(self):
159        payload = 'a' * 126
160        request = _create_request()
161        msgutil.send_message(request, payload)
162        self.assertEqual('\x81\x7e\x00\x7e' + payload,
163                         request.connection.written_data())
164
165        payload = 'a' * ((1 << 16) - 1)
166        request = _create_request()
167        msgutil.send_message(request, payload)
168        self.assertEqual('\x81\x7e\xff\xff' + payload,
169                         request.connection.written_data())
170
171    def test_send_large_message(self):
172        payload = 'a' * (1 << 16)
173        request = _create_request()
174        msgutil.send_message(request, payload)
175        self.assertEqual('\x81\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + payload,
176                         request.connection.written_data())
177
178    def test_send_message_unicode(self):
179        request = _create_request()
180        msgutil.send_message(request, u'\u65e5')
181        # U+65e5 is encoded as e6,97,a5 in UTF-8
182        self.assertEqual('\x81\x03\xe6\x97\xa5',
183                         request.connection.written_data())
184
185    def test_send_message_fragments(self):
186        request = _create_request()
187        msgutil.send_message(request, 'Hello', False)
188        msgutil.send_message(request, ' ', False)
189        msgutil.send_message(request, 'World', False)
190        msgutil.send_message(request, '!', True)
191        self.assertEqual('\x01\x05Hello\x00\x01 \x00\x05World\x80\x01!',
192                         request.connection.written_data())
193
194    def test_send_fragments_immediate_zero_termination(self):
195        request = _create_request()
196        msgutil.send_message(request, 'Hello World!', False)
197        msgutil.send_message(request, '', True)
198        self.assertEqual('\x01\x0cHello World!\x80\x00',
199                         request.connection.written_data())
200
201    def test_send_message_deflate_stream(self):
202        compress = zlib.compressobj(
203            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
204
205        request = _create_request_from_rawdata('', deflate_stream=True)
206        msgutil.send_message(request, 'Hello')
207        expected = compress.compress('\x81\x05Hello')
208        expected += compress.flush(zlib.Z_SYNC_FLUSH)
209        self.assertEqual(expected, request.connection.written_data())
210
211    def test_send_message_deflate_frame(self):
212        compress = zlib.compressobj(
213            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
214
215        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
216        request = _create_request_from_rawdata(
217            '', deflate_frame_request=extension)
218        msgutil.send_message(request, 'Hello')
219        msgutil.send_message(request, 'World')
220
221        expected = ''
222
223        compressed_hello = compress.compress('Hello')
224        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
225        compressed_hello = compressed_hello[:-4]
226        expected += '\xc1%c' % len(compressed_hello)
227        expected += compressed_hello
228
229        compressed_world = compress.compress('World')
230        compressed_world += compress.flush(zlib.Z_SYNC_FLUSH)
231        compressed_world = compressed_world[:-4]
232        expected += '\xc1%c' % len(compressed_world)
233        expected += compressed_world
234
235        self.assertEqual(expected, request.connection.written_data())
236
237    def test_send_message_deflate_frame_comp_bit(self):
238        compress = zlib.compressobj(
239            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
240
241        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
242        request = _create_request_from_rawdata(
243            '', deflate_frame_request=extension)
244        self.assertEquals(1, len(request.ws_extension_processors))
245        deflate_frame_processor = request.ws_extension_processors[0]
246        msgutil.send_message(request, 'Hello')
247        deflate_frame_processor.disable_outgoing_compression()
248        msgutil.send_message(request, 'Hello')
249        deflate_frame_processor.enable_outgoing_compression()
250        msgutil.send_message(request, 'Hello')
251
252        expected = ''
253
254        compressed_hello = compress.compress('Hello')
255        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
256        compressed_hello = compressed_hello[:-4]
257        expected += '\xc1%c' % len(compressed_hello)
258        expected += compressed_hello
259
260        expected += '\x81\x05Hello'
261
262        compressed_2nd_hello = compress.compress('Hello')
263        compressed_2nd_hello += compress.flush(zlib.Z_SYNC_FLUSH)
264        compressed_2nd_hello = compressed_2nd_hello[:-4]
265        expected += '\xc1%c' % len(compressed_2nd_hello)
266        expected += compressed_2nd_hello
267
268        self.assertEqual(expected, request.connection.written_data())
269
270    def test_send_message_deflate_frame_no_context_takeover_parameter(self):
271        compress = zlib.compressobj(
272            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
273
274        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
275        extension.add_parameter('no_context_takeover', None)
276        request = _create_request_from_rawdata(
277            '', deflate_frame_request=extension)
278        for i in xrange(3):
279            msgutil.send_message(request, 'Hello')
280
281        compressed_message = compress.compress('Hello')
282        compressed_message += compress.flush(zlib.Z_SYNC_FLUSH)
283        compressed_message = compressed_message[:-4]
284        expected = '\xc1%c' % len(compressed_message)
285        expected += compressed_message
286
287        self.assertEqual(
288            expected + expected + expected, request.connection.written_data())
289
290    def test_deflate_frame_bad_request_parameters(self):
291        """Tests that if there's anything wrong with deflate-frame extension
292        request, deflate-frame is rejected.
293        """
294
295        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
296        # max_window_bits less than 8 is illegal.
297        extension.add_parameter('max_window_bits', '7')
298        processor = DeflateFrameExtensionProcessor(extension)
299        self.assertEqual(None, processor.get_extension_response())
300
301        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
302        # max_window_bits greater than 15 is illegal.
303        extension.add_parameter('max_window_bits', '16')
304        processor = DeflateFrameExtensionProcessor(extension)
305        self.assertEqual(None, processor.get_extension_response())
306
307        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
308        # Non integer max_window_bits is illegal.
309        extension.add_parameter('max_window_bits', 'foobar')
310        processor = DeflateFrameExtensionProcessor(extension)
311        self.assertEqual(None, processor.get_extension_response())
312
313        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
314        # no_context_takeover must not have any value.
315        extension.add_parameter('no_context_takeover', 'foobar')
316        processor = DeflateFrameExtensionProcessor(extension)
317        self.assertEqual(None, processor.get_extension_response())
318
319    def test_deflate_frame_response_parameters(self):
320        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
321        processor = DeflateFrameExtensionProcessor(extension)
322        processor.set_response_window_bits(8)
323        response = processor.get_extension_response()
324        self.assertTrue(response.has_parameter('max_window_bits'))
325        self.assertEqual('8', response.get_parameter_value('max_window_bits'))
326
327        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
328        processor = DeflateFrameExtensionProcessor(extension)
329        processor.set_response_no_context_takeover(True)
330        response = processor.get_extension_response()
331        self.assertTrue(response.has_parameter('no_context_takeover'))
332        self.assertTrue(
333            response.get_parameter_value('no_context_takeover') is None)
334
335    def test_send_message_perframe_compress_deflate(self):
336        compress = zlib.compressobj(
337            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
338        extension = common.ExtensionParameter(
339            common.PERFRAME_COMPRESSION_EXTENSION)
340        extension.add_parameter('method', 'deflate')
341        request = _create_request_from_rawdata(
342                      '', perframe_compression_request=extension)
343        msgutil.send_message(request, 'Hello')
344        msgutil.send_message(request, 'World')
345
346        expected = ''
347
348        compressed_hello = compress.compress('Hello')
349        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
350        compressed_hello = compressed_hello[:-4]
351        expected += '\xc1%c' % len(compressed_hello)
352        expected += compressed_hello
353
354        compressed_world = compress.compress('World')
355        compressed_world += compress.flush(zlib.Z_SYNC_FLUSH)
356        compressed_world = compressed_world[:-4]
357        expected += '\xc1%c' % len(compressed_world)
358        expected += compressed_world
359
360        self.assertEqual(expected, request.connection.written_data())
361
362    def test_send_message_permessage_compress_deflate(self):
363        compress = zlib.compressobj(
364            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
365        extension = common.ExtensionParameter(
366            common.PERMESSAGE_COMPRESSION_EXTENSION)
367        extension.add_parameter('method', 'deflate')
368        request = _create_request_from_rawdata(
369                      '', permessage_compression_request=extension)
370        msgutil.send_message(request, 'Hello')
371        msgutil.send_message(request, 'World')
372
373        expected = ''
374
375        compressed_hello = compress.compress('Hello')
376        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
377        compressed_hello = compressed_hello[:-4]
378        expected += '\xc1%c' % len(compressed_hello)
379        expected += compressed_hello
380
381        compressed_world = compress.compress('World')
382        compressed_world += compress.flush(zlib.Z_SYNC_FLUSH)
383        compressed_world = compressed_world[:-4]
384        expected += '\xc1%c' % len(compressed_world)
385        expected += compressed_world
386
387        self.assertEqual(expected, request.connection.written_data())
388
389    def test_receive_message(self):
390        request = _create_request(
391            ('\x81\x85', 'Hello'), ('\x81\x86', 'World!'))
392        self.assertEqual('Hello', msgutil.receive_message(request))
393        self.assertEqual('World!', msgutil.receive_message(request))
394
395        payload = 'a' * 125
396        request = _create_request(('\x81\xfd', payload))
397        self.assertEqual(payload, msgutil.receive_message(request))
398
399    def test_receive_medium_message(self):
400        payload = 'a' * 126
401        request = _create_request(('\x81\xfe\x00\x7e', payload))
402        self.assertEqual(payload, msgutil.receive_message(request))
403
404        payload = 'a' * ((1 << 16) - 1)
405        request = _create_request(('\x81\xfe\xff\xff', payload))
406        self.assertEqual(payload, msgutil.receive_message(request))
407
408    def test_receive_large_message(self):
409        payload = 'a' * (1 << 16)
410        request = _create_request(
411            ('\x81\xff\x00\x00\x00\x00\x00\x01\x00\x00', payload))
412        self.assertEqual(payload, msgutil.receive_message(request))
413
414    def test_receive_length_not_encoded_using_minimal_number_of_bytes(self):
415        # Log warning on receiving bad payload length field that doesn't use
416        # minimal number of bytes but continue processing.
417
418        payload = 'a'
419        # 1 byte can be represented without extended payload length field.
420        request = _create_request(
421            ('\x81\xff\x00\x00\x00\x00\x00\x00\x00\x01', payload))
422        self.assertEqual(payload, msgutil.receive_message(request))
423
424    def test_receive_message_unicode(self):
425        request = _create_request(('\x81\x83', '\xe6\x9c\xac'))
426        # U+672c is encoded as e6,9c,ac in UTF-8
427        self.assertEqual(u'\u672c', msgutil.receive_message(request))
428
429    def test_receive_message_erroneous_unicode(self):
430        # \x80 and \x81 are invalid as UTF-8.
431        request = _create_request(('\x81\x82', '\x80\x81'))
432        # Invalid characters should raise InvalidUTF8Exception
433        self.assertRaises(InvalidUTF8Exception,
434                          msgutil.receive_message,
435                          request)
436
437    def test_receive_fragments(self):
438        request = _create_request(
439            ('\x01\x85', 'Hello'),
440            ('\x00\x81', ' '),
441            ('\x00\x85', 'World'),
442            ('\x80\x81', '!'))
443        self.assertEqual('Hello World!', msgutil.receive_message(request))
444
445    def test_receive_fragments_unicode(self):
446        # UTF-8 encodes U+6f22 into e6bca2 and U+5b57 into e5ad97.
447        request = _create_request(
448            ('\x01\x82', '\xe6\xbc'),
449            ('\x00\x82', '\xa2\xe5'),
450            ('\x80\x82', '\xad\x97'))
451        self.assertEqual(u'\u6f22\u5b57', msgutil.receive_message(request))
452
453    def test_receive_fragments_immediate_zero_termination(self):
454        request = _create_request(
455            ('\x01\x8c', 'Hello World!'), ('\x80\x80', ''))
456        self.assertEqual('Hello World!', msgutil.receive_message(request))
457
458    def test_receive_fragments_duplicate_start(self):
459        request = _create_request(
460            ('\x01\x85', 'Hello'), ('\x01\x85', 'World'))
461        self.assertRaises(msgutil.InvalidFrameException,
462                          msgutil.receive_message,
463                          request)
464
465    def test_receive_fragments_intermediate_but_not_started(self):
466        request = _create_request(('\x00\x85', 'Hello'))
467        self.assertRaises(msgutil.InvalidFrameException,
468                          msgutil.receive_message,
469                          request)
470
471    def test_receive_fragments_end_but_not_started(self):
472        request = _create_request(('\x80\x85', 'Hello'))
473        self.assertRaises(msgutil.InvalidFrameException,
474                          msgutil.receive_message,
475                          request)
476
477    def test_receive_message_discard(self):
478        request = _create_request(
479            ('\x8f\x86', 'IGNORE'), ('\x81\x85', 'Hello'),
480            ('\x8f\x89', 'DISREGARD'), ('\x81\x86', 'World!'))
481        self.assertRaises(msgutil.UnsupportedFrameException,
482                          msgutil.receive_message, request)
483        self.assertEqual('Hello', msgutil.receive_message(request))
484        self.assertRaises(msgutil.UnsupportedFrameException,
485                          msgutil.receive_message, request)
486        self.assertEqual('World!', msgutil.receive_message(request))
487
488    def test_receive_close(self):
489        request = _create_request(
490            ('\x88\x8a', struct.pack('!H', 1000) + 'Good bye'))
491        self.assertEqual(None, msgutil.receive_message(request))
492        self.assertEqual(1000, request.ws_close_code)
493        self.assertEqual('Good bye', request.ws_close_reason)
494
495    def test_receive_message_deflate_stream(self):
496        compress = zlib.compressobj(
497            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
498
499        data = compress.compress('\x81\x85' + _mask_hybi('Hello'))
500        data += compress.flush(zlib.Z_SYNC_FLUSH)
501        data += compress.compress('\x81\x89' + _mask_hybi('WebSocket'))
502        data += compress.flush(zlib.Z_FINISH)
503
504        compress = zlib.compressobj(
505            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
506
507        data += compress.compress('\x81\x85' + _mask_hybi('World'))
508        data += compress.flush(zlib.Z_SYNC_FLUSH)
509        # Close frame
510        data += compress.compress(
511            '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye'))
512        data += compress.flush(zlib.Z_SYNC_FLUSH)
513
514        request = _create_request_from_rawdata(data, deflate_stream=True)
515        self.assertEqual('Hello', msgutil.receive_message(request))
516        self.assertEqual('WebSocket', msgutil.receive_message(request))
517        self.assertEqual('World', msgutil.receive_message(request))
518
519        self.assertFalse(request.drain_received_data_called)
520
521        self.assertEqual(None, msgutil.receive_message(request))
522
523        self.assertTrue(request.drain_received_data_called)
524
525    def test_receive_message_deflate_frame(self):
526        compress = zlib.compressobj(
527            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
528
529        data = ''
530
531        compressed_hello = compress.compress('Hello')
532        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
533        compressed_hello = compressed_hello[:-4]
534        data += '\xc1%c' % (len(compressed_hello) | 0x80)
535        data += _mask_hybi(compressed_hello)
536
537        compressed_websocket = compress.compress('WebSocket')
538        compressed_websocket += compress.flush(zlib.Z_FINISH)
539        compressed_websocket += '\x00'
540        data += '\xc1%c' % (len(compressed_websocket) | 0x80)
541        data += _mask_hybi(compressed_websocket)
542
543        compress = zlib.compressobj(
544            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
545
546        compressed_world = compress.compress('World')
547        compressed_world += compress.flush(zlib.Z_SYNC_FLUSH)
548        compressed_world = compressed_world[:-4]
549        data += '\xc1%c' % (len(compressed_world) | 0x80)
550        data += _mask_hybi(compressed_world)
551
552        # Close frame
553        data += '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye')
554
555        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
556        request = _create_request_from_rawdata(
557            data, deflate_frame_request=extension)
558        self.assertEqual('Hello', msgutil.receive_message(request))
559        self.assertEqual('WebSocket', msgutil.receive_message(request))
560        self.assertEqual('World', msgutil.receive_message(request))
561
562        self.assertEqual(None, msgutil.receive_message(request))
563
564    def test_receive_message_deflate_frame_client_using_smaller_window(self):
565        """Test that frames coming from a client which is using smaller window
566        size that the server are correctly received.
567        """
568
569        # Using the smallest window bits of 8 for generating input frames.
570        compress = zlib.compressobj(
571            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -8)
572
573        data = ''
574
575        # Use a frame whose content is bigger than the clients' DEFLATE window
576        # size before compression. The content mainly consists of 'a' but
577        # repetition of 'b' is put at the head and tail so that if the window
578        # size is big, the head is back-referenced but if small, not.
579        payload = 'b' * 64 + 'a' * 1024 + 'b' * 64
580        compressed_hello = compress.compress(payload)
581        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
582        compressed_hello = compressed_hello[:-4]
583        data += '\xc1%c' % (len(compressed_hello) | 0x80)
584        data += _mask_hybi(compressed_hello)
585
586        # Close frame
587        data += '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye')
588
589        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
590        request = _create_request_from_rawdata(
591            data, deflate_frame_request=extension)
592        self.assertEqual(payload, msgutil.receive_message(request))
593
594        self.assertEqual(None, msgutil.receive_message(request))
595
596    def test_receive_message_deflate_frame_comp_bit(self):
597        compress = zlib.compressobj(
598            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
599
600        data = ''
601
602        compressed_hello = compress.compress('Hello')
603        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
604        compressed_hello = compressed_hello[:-4]
605        data += '\xc1%c' % (len(compressed_hello) | 0x80)
606        data += _mask_hybi(compressed_hello)
607
608        data += '\x81\x85' + _mask_hybi('Hello')
609
610        compress = zlib.compressobj(
611            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
612
613        compressed_2nd_hello = compress.compress('Hello')
614        compressed_2nd_hello += compress.flush(zlib.Z_SYNC_FLUSH)
615        compressed_2nd_hello = compressed_2nd_hello[:-4]
616        data += '\xc1%c' % (len(compressed_2nd_hello) | 0x80)
617        data += _mask_hybi(compressed_2nd_hello)
618
619        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
620        request = _create_request_from_rawdata(
621            data, deflate_frame_request=extension)
622        for i in xrange(3):
623            self.assertEqual('Hello', msgutil.receive_message(request))
624
625    def test_receive_message_perframe_compression_frame(self):
626        compress = zlib.compressobj(
627            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
628
629        data = ''
630
631        compressed_hello = compress.compress('Hello')
632        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
633        compressed_hello = compressed_hello[:-4]
634        data += '\xc1%c' % (len(compressed_hello) | 0x80)
635        data += _mask_hybi(compressed_hello)
636
637        compressed_websocket = compress.compress('WebSocket')
638        compressed_websocket += compress.flush(zlib.Z_FINISH)
639        compressed_websocket += '\x00'
640        data += '\xc1%c' % (len(compressed_websocket) | 0x80)
641        data += _mask_hybi(compressed_websocket)
642
643        compress = zlib.compressobj(
644            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
645
646        compressed_world = compress.compress('World')
647        compressed_world += compress.flush(zlib.Z_SYNC_FLUSH)
648        compressed_world = compressed_world[:-4]
649        data += '\xc1%c' % (len(compressed_world) | 0x80)
650        data += _mask_hybi(compressed_world)
651
652        # Close frame
653        data += '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye')
654
655        extension = common.ExtensionParameter(
656            common.PERFRAME_COMPRESSION_EXTENSION)
657        extension.add_parameter('method', 'deflate')
658        request = _create_request_from_rawdata(
659            data, perframe_compression_request=extension)
660        self.assertEqual('Hello', msgutil.receive_message(request))
661        self.assertEqual('WebSocket', msgutil.receive_message(request))
662        self.assertEqual('World', msgutil.receive_message(request))
663
664        self.assertEqual(None, msgutil.receive_message(request))
665
666    def test_receive_message_permessage_deflate_compression(self):
667        compress = zlib.compressobj(
668            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
669
670        data = ''
671
672        compressed_hello = compress.compress('HelloWebSocket')
673        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
674        compressed_hello = compressed_hello[:-4]
675        split_position = len(compressed_hello) / 2
676        data += '\x41%c' % (split_position | 0x80)
677        data += _mask_hybi(compressed_hello[:split_position])
678
679        data += '\x80%c' % ((len(compressed_hello) - split_position) | 0x80)
680        data += _mask_hybi(compressed_hello[split_position:])
681
682        compress = zlib.compressobj(
683            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
684
685        compressed_world = compress.compress('World')
686        compressed_world += compress.flush(zlib.Z_SYNC_FLUSH)
687        compressed_world = compressed_world[:-4]
688        data += '\xc1%c' % (len(compressed_world) | 0x80)
689        data += _mask_hybi(compressed_world)
690
691        # Close frame
692        data += '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye')
693
694        extension = common.ExtensionParameter(
695            common.PERMESSAGE_COMPRESSION_EXTENSION)
696        extension.add_parameter('method', 'deflate')
697        request = _create_request_from_rawdata(
698            data, permessage_compression_request=extension)
699        self.assertEqual('HelloWebSocket', msgutil.receive_message(request))
700        self.assertEqual('World', msgutil.receive_message(request))
701
702        self.assertEqual(None, msgutil.receive_message(request))
703
704    def test_send_longest_close(self):
705        reason = 'a' * 123
706        request = _create_request(
707            ('\x88\xfd',
708             struct.pack('!H', common.STATUS_NORMAL_CLOSURE) + reason))
709        request.ws_stream.close_connection(common.STATUS_NORMAL_CLOSURE,
710                                           reason)
711        self.assertEqual(request.ws_close_code, common.STATUS_NORMAL_CLOSURE)
712        self.assertEqual(request.ws_close_reason, reason)
713
714    def test_send_close_too_long(self):
715        request = _create_request()
716        self.assertRaises(msgutil.BadOperationException,
717                          Stream.close_connection,
718                          request.ws_stream,
719                          common.STATUS_NORMAL_CLOSURE,
720                          'a' * 124)
721
722    def test_send_close_inconsistent_code_and_reason(self):
723        request = _create_request()
724        # reason parameter must not be specified when code is None.
725        self.assertRaises(msgutil.BadOperationException,
726                          Stream.close_connection,
727                          request.ws_stream,
728                          None,
729                          'a')
730
731    def test_send_ping(self):
732        request = _create_request()
733        msgutil.send_ping(request, 'Hello World!')
734        self.assertEqual('\x89\x0cHello World!',
735                         request.connection.written_data())
736
737    def test_send_longest_ping(self):
738        request = _create_request()
739        msgutil.send_ping(request, 'a' * 125)
740        self.assertEqual('\x89\x7d' + 'a' * 125,
741                         request.connection.written_data())
742
743    def test_send_ping_too_long(self):
744        request = _create_request()
745        self.assertRaises(msgutil.BadOperationException,
746                          msgutil.send_ping,
747                          request,
748                          'a' * 126)
749
750    def test_receive_ping(self):
751        """Tests receiving a ping control frame."""
752
753        def handler(request, message):
754            request.called = True
755
756        # Stream automatically respond to ping with pong without any action
757        # by application layer.
758        request = _create_request(
759            ('\x89\x85', 'Hello'), ('\x81\x85', 'World'))
760        self.assertEqual('World', msgutil.receive_message(request))
761        self.assertEqual('\x8a\x05Hello',
762                         request.connection.written_data())
763
764        request = _create_request(
765            ('\x89\x85', 'Hello'), ('\x81\x85', 'World'))
766        request.on_ping_handler = handler
767        self.assertEqual('World', msgutil.receive_message(request))
768        self.assertTrue(request.called)
769
770    def test_receive_longest_ping(self):
771        request = _create_request(
772            ('\x89\xfd', 'a' * 125), ('\x81\x85', 'World'))
773        self.assertEqual('World', msgutil.receive_message(request))
774        self.assertEqual('\x8a\x7d' + 'a' * 125,
775                         request.connection.written_data())
776
777    def test_receive_ping_too_long(self):
778        request = _create_request(('\x89\xfe\x00\x7e', 'a' * 126))
779        self.assertRaises(msgutil.InvalidFrameException,
780                          msgutil.receive_message,
781                          request)
782
783    def test_receive_pong(self):
784        """Tests receiving a pong control frame."""
785
786        def handler(request, message):
787            request.called = True
788
789        request = _create_request(
790            ('\x8a\x85', 'Hello'), ('\x81\x85', 'World'))
791        request.on_pong_handler = handler
792        msgutil.send_ping(request, 'Hello')
793        self.assertEqual('\x89\x05Hello',
794                         request.connection.written_data())
795        # Valid pong is received, but receive_message won't return for it.
796        self.assertEqual('World', msgutil.receive_message(request))
797        # Check that nothing was written after receive_message call.
798        self.assertEqual('\x89\x05Hello',
799                         request.connection.written_data())
800
801        self.assertTrue(request.called)
802
803    def test_receive_unsolicited_pong(self):
804        # Unsolicited pong is allowed from HyBi 07.
805        request = _create_request(
806            ('\x8a\x85', 'Hello'), ('\x81\x85', 'World'))
807        msgutil.receive_message(request)
808
809        request = _create_request(
810            ('\x8a\x85', 'Hello'), ('\x81\x85', 'World'))
811        msgutil.send_ping(request, 'Jumbo')
812        # Body mismatch.
813        msgutil.receive_message(request)
814
815    def test_ping_cannot_be_fragmented(self):
816        request = _create_request(('\x09\x85', 'Hello'))
817        self.assertRaises(msgutil.InvalidFrameException,
818                          msgutil.receive_message,
819                          request)
820
821    def test_ping_with_too_long_payload(self):
822        request = _create_request(('\x89\xfe\x01\x00', 'a' * 256))
823        self.assertRaises(msgutil.InvalidFrameException,
824                          msgutil.receive_message,
825                          request)
826
827
class MessageTestHixie75(unittest.TestCase):
    """Tests for draft-hixie-thewebsocketprotocol-76 stream class.

    Text frames in these tests are a 0x00 byte, the UTF-8 encoded payload
    and a terminating 0xff byte.
    """

    def test_send_message(self):
        request = _create_request_hixie75()
        msgutil.send_message(request, 'Hello')
        written = request.connection.written_data()
        self.assertEqual('\x00Hello\xff', written)

    def test_send_message_unicode(self):
        request = _create_request_hixie75()
        msgutil.send_message(request, u'\u65e5')
        # U+65e5 is encoded as e6,97,a5 in UTF-8
        written = request.connection.written_data()
        self.assertEqual('\x00\xe6\x97\xa5\xff', written)

    def test_receive_message(self):
        request = _create_request_hixie75('\x00Hello\xff\x00World!\xff')
        for expected in ('Hello', 'World!'):
            self.assertEqual(expected, msgutil.receive_message(request))

    def test_receive_message_unicode(self):
        # U+672c is encoded as e6,9c,ac in UTF-8
        request = _create_request_hixie75('\x00\xe6\x9c\xac\xff')
        self.assertEqual(u'\u672c', msgutil.receive_message(request))

    def test_receive_message_erroneous_unicode(self):
        # \x80 and \x81 are invalid as UTF-8; each one is expected to come
        # back as U+fffd REPLACEMENT CHARACTER.
        request = _create_request_hixie75('\x00\x80\x81\xff')
        self.assertEqual(u'\ufffd\ufffd', msgutil.receive_message(request))

    def test_receive_message_discard(self):
        # Frames whose type byte is not 0x00 (a length-prefixed 0x80 frame
        # and a 0xff-terminated 0x01 frame) are skipped.
        raw = ('\x80\x06IGNORE\x00Hello\xff'
               '\x01DISREGARD\xff\x00World!\xff')
        request = _create_request_hixie75(raw)
        for expected in ('Hello', 'World!'):
            self.assertEqual(expected, msgutil.receive_message(request))
865
866
class MessageReceiverTest(unittest.TestCase):
    """Tests the Stream class using MessageReceiver."""

    # Upper bound, in seconds, for waiting on output produced by the
    # receiver. A broken receiver then fails the test with Queue.Empty
    # instead of hanging the whole suite.
    _QUEUE_TIMEOUT_SEC = 10

    def test_queue(self):
        request = _create_blocking_request()
        receiver = msgutil.MessageReceiver(request)

        # Nothing has been fed to the connection yet.
        self.assertEqual(None, receiver.receive_nowait())

        # 0x81: FIN + text opcode; 0x86: masked, 6-byte payload.
        request.connection.put_bytes('\x81\x86' + _mask_hybi('Hello!'))
        self.assertEqual('Hello!', receiver.receive())

    def test_onmessage(self):
        onmessage_queue = Queue.Queue()

        def onmessage_handler(message):
            onmessage_queue.put(message)

        request = _create_blocking_request()
        # Not used directly: the receiver hands each message to
        # onmessage_handler (presumably from its own thread).
        receiver = msgutil.MessageReceiver(request, onmessage_handler)

        request.connection.put_bytes('\x81\x86' + _mask_hybi('Hello!'))
        self.assertEqual(
            'Hello!', onmessage_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))
890
891
class MessageReceiverHixie75Test(unittest.TestCase):
    """Tests the StreamHixie75 class using MessageReceiver."""

    # Upper bound, in seconds, for waiting on output produced by the
    # receiver. A broken receiver then fails the test with Queue.Empty
    # instead of hanging the whole suite.
    _QUEUE_TIMEOUT_SEC = 10

    def test_queue(self):
        request = _create_blocking_request_hixie75()
        receiver = msgutil.MessageReceiver(request)

        # Nothing has been fed to the connection yet.
        self.assertEqual(None, receiver.receive_nowait())

        # Hixie-style text frame: 0x00, payload, 0xff.
        request.connection.put_bytes('\x00Hello!\xff')
        self.assertEqual('Hello!', receiver.receive())

    def test_onmessage(self):
        onmessage_queue = Queue.Queue()

        def onmessage_handler(message):
            onmessage_queue.put(message)

        request = _create_blocking_request_hixie75()
        # Not used directly: the receiver hands each message to
        # onmessage_handler (presumably from its own thread).
        receiver = msgutil.MessageReceiver(request, onmessage_handler)

        request.connection.put_bytes('\x00Hello!\xff')
        self.assertEqual(
            'Hello!', onmessage_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))
915
916
class MessageSenderTest(unittest.TestCase):
    """Tests the Stream class using MessageSender."""

    # Upper bound, in seconds, for waiting on frames written by the sender
    # thread. A broken sender then fails the test with Queue.Empty instead
    # of hanging the whole suite.
    _QUEUE_TIMEOUT_SEC = 10

    def test_send(self):
        request = _create_blocking_request()
        sender = msgutil.MessageSender(request)

        # send() is expected to return only once the frame was written, so
        # written_data() can be inspected directly afterwards.
        sender.send('World')
        self.assertEqual('\x81\x05World', request.connection.written_data())

    def test_send_nowait(self):
        # Use a queue to check the bytes written by MessageSender.
        # request.connection.written_data() cannot be used here because
        # MessageSender runs in a separate thread.
        send_queue = Queue.Queue()

        def write(data):
            send_queue.put(data)

        request = _create_blocking_request()
        request.connection.write = write

        sender = msgutil.MessageSender(request)

        sender.send_nowait('Hello')
        sender.send_nowait('World')
        # Frames must be written in the order they were queued.
        self.assertEqual(
            '\x81\x05Hello', send_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))
        self.assertEqual(
            '\x81\x05World', send_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))
945
946
class MessageSenderHixie75Test(unittest.TestCase):
    """Tests the StreamHixie75 class using MessageSender."""

    # Upper bound, in seconds, for waiting on frames written by the sender
    # thread. A broken sender then fails the test with Queue.Empty instead
    # of hanging the whole suite.
    _QUEUE_TIMEOUT_SEC = 10

    def test_send(self):
        request = _create_blocking_request_hixie75()
        sender = msgutil.MessageSender(request)

        # send() is expected to return only once the frame was written, so
        # written_data() can be inspected directly afterwards.
        sender.send('World')
        self.assertEqual('\x00World\xff', request.connection.written_data())

    def test_send_nowait(self):
        # Use a queue to check the bytes written by MessageSender.
        # request.connection.written_data() cannot be used here because
        # MessageSender runs in a separate thread.
        send_queue = Queue.Queue()

        def write(data):
            send_queue.put(data)

        request = _create_blocking_request_hixie75()
        request.connection.write = write

        sender = msgutil.MessageSender(request)

        sender.send_nowait('Hello')
        sender.send_nowait('World')
        # Frames must be written in the order they were queued.
        self.assertEqual(
            '\x00Hello\xff', send_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))
        self.assertEqual(
            '\x00World\xff', send_queue.get(timeout=self._QUEUE_TIMEOUT_SEC))
975
976
if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()
979
980
981# vi:sts=4 sw=4 et
982