1#!/usr/bin/env python
2#
3# Copyright 2012, Google Inc.
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11# notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13# copyright notice, this list of conditions and the following disclaimer
14# in the documentation and/or other materials provided with the
15# distribution.
16#     * Neither the name of Google Inc. nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32
33"""Tests for mux module."""
34
35import Queue
36import logging
37import optparse
38import unittest
39import struct
40import sys
41
42import set_sys_path  # Update sys.path to locate mod_pywebsocket module.
43
44from mod_pywebsocket import common
45from mod_pywebsocket import mux
46from mod_pywebsocket._stream_base import ConnectionTerminatedException
47from mod_pywebsocket._stream_hybi import Stream
48from mod_pywebsocket._stream_hybi import StreamOptions
49from mod_pywebsocket._stream_hybi import create_binary_frame
50from mod_pywebsocket._stream_hybi import parse_frame
51
52import mock
53
54
55class _OutgoingChannelData(object):
56    def __init__(self):
57        self.messages = []
58        self.control_messages = []
59
60        self.current_opcode = None
61        self.pending_fragments = []
62
63
class _MockMuxConnection(mock.MockBlockingConn):
    """Mock class of mod_python connection for mux.

    Each write() is parsed as exactly one outer (physical) WebSocket frame.
    Once an outer message is complete, its payload is parsed as a mux frame
    and recorded either as a control block (mux._CONTROL_CHANNEL_ID) or as
    an inner frame of a logical channel, where fragments are reassembled
    into messages.
    """

    def __init__(self):
        mock.MockBlockingConn.__init__(self)
        # Remaining data of every control-channel mux frame, in write order.
        self._control_blocks = []
        # channel_id -> _OutgoingChannelData for each logical channel.
        self._channel_data = {}

        # Reassembly state for fragmented outer frames.
        self._current_opcode = None
        self._pending_fragments = []

    @staticmethod
    def _check_fragment_opcode(current_opcode, opcode):
        """Validate opcode against the message currently being reassembled.

        current_opcode is None when no message is in progress; then opcode
        must not be a continuation. Otherwise opcode must be a continuation.

        Raises:
            Exception: if the opcode breaks the fragmentation sequence.
        """
        if current_opcode is None:
            if opcode == common.OPCODE_CONTINUATION:
                raise Exception('Sending invalid continuation opcode')
        elif opcode != common.OPCODE_CONTINUATION:
            raise Exception('Sending invalid opcode %d' % opcode)

    def write(self, data):
        """Override MockBlockingConn.write.

        Raises:
            ConnectionTerminatedException: if data is shorter than the frame
                it encodes.
            Exception: on an invalid outer or inner fragmentation sequence.
        """
        # The read cursor is local to this call (a one-element list so the
        # nested function can update it without 'nonlocal'), so no stale
        # parse state is left on the instance.
        position = [0]

        def _receive_bytes(length):
            start = position[0]
            if start + length > len(data):
                raise ConnectionTerminatedException(
                    'Failed to receive %d bytes from encapsulated '
                    'frame' % length)
            position[0] = start + length
            return data[start:start + length]

        opcode, payload, fin, _, _, _ = parse_frame(
            _receive_bytes, unmask_receive=False)

        self._pending_fragments.append(payload)
        self._check_fragment_opcode(self._current_opcode, opcode)
        if self._current_opcode is None:
            self._current_opcode = opcode
        if not fin:
            return

        inner_frame_data = ''.join(self._pending_fragments)
        self._pending_fragments = []
        self._current_opcode = None

        parser = mux._MuxFramePayloadParser(inner_frame_data)
        channel_id = parser.read_channel_id()
        if channel_id == mux._CONTROL_CHANNEL_ID:
            self._control_blocks.append(parser.remaining_data())
            return

        if channel_id not in self._channel_data:
            self._channel_data[channel_id] = _OutgoingChannelData()
        channel_data = self._channel_data[channel_id]

        (inner_fin, _, _, _, inner_opcode,
         inner_payload) = parser.read_inner_frame()
        channel_data.pending_fragments.append(inner_payload)

        self._check_fragment_opcode(channel_data.current_opcode, inner_opcode)
        if channel_data.current_opcode is None:
            channel_data.current_opcode = inner_opcode
        if not inner_fin:
            return

        message = ''.join(channel_data.pending_fragments)
        channel_data.pending_fragments = []

        if channel_data.current_opcode in (common.OPCODE_TEXT,
                                           common.OPCODE_BINARY):
            channel_data.messages.append(message)
        else:
            channel_data.control_messages.append(
                {'opcode': channel_data.current_opcode,
                 'message': message})
        channel_data.current_opcode = None

    def get_written_control_blocks(self):
        """Return the control blocks written so far, in write order."""
        return self._control_blocks

    def get_written_messages(self, channel_id):
        """Return completed data messages written to channel_id.

        Raises KeyError if nothing was ever written to channel_id.
        """
        return self._channel_data[channel_id].messages

    def get_written_control_messages(self, channel_id):
        """Return completed non-data messages written to channel_id.

        Raises KeyError if nothing was ever written to channel_id.
        """
        return self._channel_data[channel_id].control_messages
153
154
155class _ChannelEvent(object):
156    """A structure that records channel events."""
157
158    def __init__(self):
159        self.messages = []
160        self.exception = None
161        self.client_initiated_closing = False
162
163
164class _MuxMockDispatcher(object):
165    """Mock class of dispatch.Dispatcher for mux."""
166
167    def __init__(self):
168        self.channel_events = {}
169
170    def do_extra_handshake(self, request):
171        pass
172
173    def _do_echo(self, request, channel_events):
174        while True:
175            message = request.ws_stream.receive_message()
176            if message == None:
177                channel_events.client_initiated_closing = True
178                return
179            if message == 'Goodbye':
180                return
181            channel_events.messages.append(message)
182            # echo back
183            request.ws_stream.send_message(message)
184
185    def _do_ping(self, request, channel_events):
186        request.ws_stream.send_ping('Ping!')
187
188    def transfer_data(self, request):
189        self.channel_events[request.channel_id] = _ChannelEvent()
190
191        try:
192            # Note: more handler will be added.
193            if request.uri.endswith('echo'):
194                self._do_echo(request,
195                              self.channel_events[request.channel_id])
196            elif request.uri.endswith('ping'):
197                self._do_ping(request,
198                              self.channel_events[request.channel_id])
199            else:
200                raise ValueError('Cannot handle path %r' % request.path)
201            if not request.server_terminated:
202                request.ws_stream.close_connection()
203        except ConnectionTerminatedException, e:
204            self.channel_events[request.channel_id].exception = e
205        except Exception, e:
206            self.channel_events[request.channel_id].exception = e
207            raise
208
209
def _create_mock_request():
    """Return a mock physical-connection request for '/echo' with mux enabled.

    The request uses a _MockMuxConnection, a hybi Stream with default
    options, no mux extensions and a mux quota of 8 KiB.
    """
    request = mock.MockRequest(
        uri='/echo',
        headers_in={
            'Host': 'server.example.com',
            'Upgrade': 'websocket',
            'Connection': 'Upgrade',
            'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
            'Sec-WebSocket-Version': '13',
            'Origin': 'http://example.com',
        },
        connection=_MockMuxConnection())
    request.ws_stream = Stream(request, options=StreamOptions())
    request.mux_quota = 8 * 1024
    request.mux_extensions = []
    request.mux = True
    return request
225
226
227def _create_add_channel_request_frame(channel_id, encoding, encoded_handshake):
228    if encoding != 0 and encoding != 1:
229        raise ValueError('Invalid encoding')
230    block = mux._create_control_block_length_value(
231               channel_id, mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, encoding,
232               encoded_handshake)
233    payload = mux._encode_channel_id(mux._CONTROL_CHANNEL_ID) + block
234    return create_binary_frame(payload, mask=True)
235
236
def _create_logical_frame(channel_id, message, opcode=common.OPCODE_BINARY,
                          mask=True):
    """Wrap message in a single unfragmented inner frame for channel_id.

    The result is an outer binary frame (masked unless mask is False) whose
    payload is the encoded channel id, one header byte, and message.
    """
    # High bit (0x80) is the FIN flag: the inner frame is never fragmented.
    header_byte = chr(0x80 | opcode)
    payload = ''.join([mux._encode_channel_id(channel_id), header_byte,
                       message])
    return create_binary_frame(payload, mask=mask)
242
243
244def _create_request_header(path='/echo'):
245    return (
246        'GET %s HTTP/1.1\r\n'
247        'Host: server.example.com\r\n'
248        'Upgrade: websocket\r\n'
249        'Connection: Upgrade\r\n'
250        'Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n'
251        'Sec-WebSocket-Version: 13\r\n'
252        'Origin: http://example.com\r\n'
253        '\r\n') % path
254
255
class MuxTest(unittest.TestCase):
    """A unittest for mux module."""

    def test_channel_id_decode(self):
        # Five ids packed back to back using 1, 1, 2, 3 and 4 bytes.
        encoded_ids = '\x00\x01\xbf\xff\xdf\xff\xff\xff\xff\xff\xff'
        payload_parser = mux._MuxFramePayloadParser(encoded_ids)
        for expected_id in (0, 1, 2 ** 14 - 1, 2 ** 21 - 1, 2 ** 29 - 1):
            self.assertEqual(expected_id, payload_parser.read_channel_id())
        # Every byte must have been consumed.
        self.assertEqual(len(encoded_ids), payload_parser._read_position)

    def test_channel_id_encode(self):
        expectations = [
            (0, '\x00'),
            (2 ** 14 - 1, '\xbf\xff'),
            (2 ** 14, '\xc0@\x00'),
            (2 ** 21 - 1, '\xdf\xff\xff'),
            (2 ** 21, '\xe0 \x00\x00'),
            (2 ** 29 - 1, '\xff\xff\xff\xff'),
        ]
        for channel_id, expected_bytes in expectations:
            self.assertEqual(expected_bytes,
                             mux._encode_channel_id(channel_id))
        # channel_id is too large
        self.assertRaises(ValueError, mux._encode_channel_id, 2 ** 29)

    def test_create_control_block_length_value(self):
        # (channel_id, opcode, flags, value, expected bytes preceding value)
        expectations = [
            (1, mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, 0x7, 'Hello, world!',
             '\x1c\x01\x0d'),
            (2, mux._MUX_OPCODE_ADD_CHANNEL_RESPONSE, 0x0, 'a' * (2 ** 8),
             '\x21\x02\x01\x00'),
            (3, mux._MUX_OPCODE_DROP_CHANNEL, 0x0, 'b' * (2 ** 16),
             '\x62\x03\x01\x00\x00'),
        ]
        for channel_id, opcode, flags, value, prefix in expectations:
            control_block = mux._create_control_block_length_value(
                channel_id=channel_id, opcode=opcode, flags=flags,
                value=value)
            self.assertEqual(prefix + value, control_block)

    def test_read_control_blocks(self):
        payload = ('\x00\x01\x00'
                   '\x61\x02\x01\x00%s'
                   '\x0a\x03\x01\x00\x00%s'
                   '\x63\x04\x01\x00\x00\x00%s') % (
            'a' * 0x0100, 'b' * 0x010000, 'c' * 0x01000000)
        payload_parser = mux._MuxFramePayloadParser(payload)
        control_blocks = list(payload_parser.read_control_blocks())
        self.assertEqual(4, len(control_blocks))
        first, second, third, fourth = control_blocks

        self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, first.opcode)
        self.assertEqual(0, first.encoding)
        self.assertEqual(0, len(first.encoded_handshake))

        self.assertEqual(mux._MUX_OPCODE_DROP_CHANNEL, second.opcode)
        self.assertEqual(0, second.mux_error)
        self.assertEqual(0x0100, len(second.reason))

        self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, third.opcode)
        self.assertEqual(2, third.encoding)
        self.assertEqual(0x010000, len(third.encoded_handshake))

        self.assertEqual(mux._MUX_OPCODE_DROP_CHANNEL, fourth.opcode)
        self.assertEqual(0, fourth.mux_error)
        self.assertEqual(0x01000000, len(fourth.reason))

        self.assertEqual(len(payload), payload_parser._read_position)

    def test_create_add_channel_response(self):
        expectations = [
            (dict(channel_id=1, encoded_handshake='FooBar', encoding=0,
                  rejected=False),
             '\x82\x0a\x00\x20\x01\x06FooBar'),
            (dict(channel_id=2, encoded_handshake='Hello', encoding=1,
                  rejected=True),
             '\x82\x09\x00\x34\x02\x05Hello'),
        ]
        for arguments, expected_frame in expectations:
            self.assertEqual(expected_frame,
                             mux._create_add_channel_response(**arguments))

    def test_drop_channel(self):
        self.assertEqual(
            '\x82\x04\x00\x60\x01\x00',
            mux._create_drop_channel(channel_id=1, reason='',
                                     mux_error=False))
        self.assertEqual(
            '\x82\x09\x00\x70\x01\x05error',
            mux._create_drop_channel(channel_id=1, reason='error',
                                     mux_error=True))
        # reason must be empty if mux_error is False.
        self.assertRaises(ValueError, mux._create_drop_channel,
                          1, 'FooBar', False)

    def test_parse_request_text(self):
        method, request_path, http_version, parsed_headers = (
            mux._parse_request_text(_create_request_header()))
        self.assertEqual('GET', method)
        self.assertEqual('/echo', request_path)
        self.assertEqual('HTTP/1.1', http_version)
        self.assertEqual(6, len(parsed_headers))
        expected_headers = {
            'Host': 'server.example.com',
            'Upgrade': 'websocket',
            'Connection': 'Upgrade',
            'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
            'Sec-WebSocket-Version': '13',
            'Origin': 'http://example.com',
        }
        for header_name, header_value in expected_headers.items():
            self.assertEqual(header_value, parsed_headers[header_name])
385
386
class MuxHandlerTest(unittest.TestCase):
    """Tests for mux._MuxHandler driven through a _MockMuxConnection."""

    def _create_mux_handler(self, slots=None, send_quota=None):
        """Create a mock request, start a _MuxHandler on it, add slots.

        Args:
            slots: channel slots to add; defaults to
                mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS.
            send_quota: send quota passed with the slots; defaults to
                mux._INITIAL_QUOTA_FOR_CLIENT.

        Returns:
            A (request, dispatcher, mux_handler) tuple.
        """
        if slots is None:
            slots = mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS
        if send_quota is None:
            send_quota = mux._INITIAL_QUOTA_FOR_CLIENT
        request = _create_mock_request()
        dispatcher = _MuxMockDispatcher()
        mux_handler = mux._MuxHandler(request, dispatcher)
        mux_handler.start()
        mux_handler.add_channel_slots(slots=slots, send_quota=send_quota)
        return request, dispatcher, mux_handler

    @staticmethod
    def _put_add_channel_request(request, channel_id, path='/echo',
                                 encoded_handshake=None):
        """Queue an AddChannelRequest (encoding 0) for channel_id.

        encoded_handshake defaults to a valid opening handshake for path.
        """
        if encoded_handshake is None:
            encoded_handshake = _create_request_header(path=path)
        request.connection.put_bytes(_create_add_channel_request_frame(
            channel_id=channel_id, encoding=0,
            encoded_handshake=encoded_handshake))

    @staticmethod
    def _put_flow_control(request, channel_id, replenished_quota):
        """Queue a masked FlowControl block replenishing channel_id."""
        request.connection.put_bytes(mux._create_flow_control(
            channel_id=channel_id, replenished_quota=replenished_quota,
            outer_frame_mask=True))

    @staticmethod
    def _put_logical_frame(request, channel_id, message, **kwargs):
        """Queue one logical frame; kwargs go to _create_logical_frame."""
        request.connection.put_bytes(_create_logical_frame(
            channel_id=channel_id, message=message, **kwargs))

    def test_add_channel(self):
        request, dispatcher, mux_handler = self._create_mux_handler()

        self._put_add_channel_request(request, channel_id=2)
        self._put_flow_control(request, channel_id=2, replenished_quota=5)
        self._put_add_channel_request(request, channel_id=3)
        self._put_flow_control(request, channel_id=3, replenished_quota=5)

        self._put_logical_frame(request, 2, 'Hello')
        self._put_logical_frame(request, 3, 'World')
        self._put_logical_frame(request, 1, 'Goodbye')
        self._put_logical_frame(request, 2, 'Goodbye')
        self._put_logical_frame(request, 3, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)

        self.assertEqual([], dispatcher.channel_events[1].messages)
        self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
        self.assertEqual(['World'], dispatcher.channel_events[3].messages)
        # Channel 2
        self.assertEqual(['Hello'], request.connection.get_written_messages(2))
        # Channel 3
        self.assertEqual(['World'], request.connection.get_written_messages(3))
        control_blocks = request.connection.get_written_control_blocks()
        # There should be 9 control blocks:
        #   - 1 NewChannelSlot
        #   - 2 AddChannelResponses for channel id 2 and 3
        #   - 6 FlowControls for channel id 1 (initialize), 'Hello', 'World',
        #     and 3 'Goodbye's
        self.assertEqual(9, len(control_blocks))

    def test_add_channel_incomplete_handshake(self):
        request, dispatcher, mux_handler = self._create_mux_handler()

        self._put_add_channel_request(
            request, channel_id=2, encoded_handshake='GET /echo HTTP/1.1')
        self._put_logical_frame(request, 1, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)

        self.assertTrue(1 in dispatcher.channel_events)
        self.assertFalse(2 in dispatcher.channel_events)

    def test_add_channel_invalid_version_handshake(self):
        request, dispatcher, mux_handler = self._create_mux_handler()

        # An old (hixie-76 style) handshake, which must be rejected.
        encoded_handshake = (
            'GET /echo HTTP/1.1\r\n'
            'Host: example.com\r\n'
            'Connection: Upgrade\r\n'
            'Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n'
            'Sec-WebSocket-Protocol: sample\r\n'
            'Upgrade: WebSocket\r\n'
            'Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n'
            'Origin: http://example.com\r\n'
            '\r\n'
            '^n:ds[4U')
        self._put_add_channel_request(
            request, channel_id=2, encoded_handshake=encoded_handshake)
        self._put_logical_frame(request, 1, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)

        self.assertTrue(1 in dispatcher.channel_events)
        self.assertFalse(2 in dispatcher.channel_events)

    def test_receive_drop_channel(self):
        request, dispatcher, mux_handler = self._create_mux_handler()

        self._put_add_channel_request(request, channel_id=2)
        request.connection.put_bytes(
            mux._create_drop_channel(channel_id=2, outer_frame_mask=True))

        # Terminate implicitly opened channel.
        self._put_logical_frame(request, 1, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)

        exception = dispatcher.channel_events[2].exception
        # Exact class, not a subclass.
        self.assertEqual(ConnectionTerminatedException, exception.__class__)

    def test_receive_ping_frame(self):
        request, dispatcher, mux_handler = self._create_mux_handler()

        self._put_add_channel_request(request, channel_id=2)
        self._put_flow_control(request, channel_id=2, replenished_quota=12)
        self._put_logical_frame(request, 2, 'Hello World!',
                                opcode=common.OPCODE_PING)
        self._put_logical_frame(request, 1, 'Goodbye')
        self._put_logical_frame(request, 2, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)

        messages = request.connection.get_written_control_messages(2)
        self.assertEqual(common.OPCODE_PONG, messages[0]['opcode'])
        self.assertEqual('Hello World!', messages[0]['message'])

    def test_send_ping(self):
        request, dispatcher, mux_handler = self._create_mux_handler()

        self._put_add_channel_request(request, channel_id=2, path='/ping')
        self._put_flow_control(request, channel_id=2, replenished_quota=5)
        self._put_logical_frame(request, 1, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)

        messages = request.connection.get_written_control_messages(2)
        self.assertEqual(common.OPCODE_PING, messages[0]['opcode'])
        self.assertEqual('Ping!', messages[0]['message'])

    def test_two_flow_control(self):
        request, dispatcher, mux_handler = self._create_mux_handler()

        self._put_add_channel_request(request, channel_id=2)
        # Replenish 5 bytes.
        self._put_flow_control(request, channel_id=2, replenished_quota=5)
        # Send 10 bytes. The server will try echo back 10 bytes.
        self._put_logical_frame(request, 2, 'HelloWorld')
        # Replenish 5 bytes again.
        self._put_flow_control(request, channel_id=2, replenished_quota=5)
        self._put_logical_frame(request, 1, 'Goodbye')
        self._put_logical_frame(request, 2, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)

        messages = request.connection.get_written_messages(2)
        self.assertEqual(['HelloWorld'], messages)

    def test_no_send_quota_on_server(self):
        request, dispatcher, mux_handler = self._create_mux_handler()

        self._put_add_channel_request(request, channel_id=2)
        self._put_logical_frame(request, 2, 'HelloWorld')
        self._put_logical_frame(request, 1, 'Goodbye')

        mux_handler.wait_until_done(timeout=1)

        # No message should be sent on channel 2.
        self.assertRaises(KeyError,
                          request.connection.get_written_messages,
                          2)

    def test_quota_violation_by_client(self):
        request, dispatcher, mux_handler = self._create_mux_handler(
            send_quota=0)

        self._put_add_channel_request(request, channel_id=2)
        self._put_logical_frame(request, 2, 'HelloWorld')
        self._put_logical_frame(request, 1, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)
        control_blocks = request.connection.get_written_control_blocks()
        # The first block is FlowControl for channel id 1.
        # The next two blocks are NewChannelSlot and AddChannelResponse.
        # The 4th block or the last block should be DropChannels for channel 2.
        # (The order can be mixed up)
        # The remaining block should be FlowControl for 'Goodbye'.
        self.assertEqual(5, len(control_blocks))
        # Opcode in the top 3 bits plus the flag bit (1 << 4) that the
        # DropChannel with mux_error=True test above also sets ('\x70').
        expected_opcode_and_flag = ((mux._MUX_OPCODE_DROP_CHANNEL << 5) |
                                    (1 << 4))
        self.assertTrue((expected_opcode_and_flag ==
                        (ord(control_blocks[3][0]) & 0xf0)) or
                        (expected_opcode_and_flag ==
                        (ord(control_blocks[4][0]) & 0xf0)))

    def test_fragmented_control_message(self):
        request, dispatcher, mux_handler = self._create_mux_handler()

        self._put_add_channel_request(request, channel_id=2, path='/ping')
        # Replenish total 5 bytes in 3 FlowControls.
        self._put_flow_control(request, channel_id=2, replenished_quota=1)
        self._put_flow_control(request, channel_id=2, replenished_quota=2)
        self._put_flow_control(request, channel_id=2, replenished_quota=2)
        self._put_logical_frame(request, 1, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)

        messages = request.connection.get_written_control_messages(2)
        self.assertEqual(common.OPCODE_PING, messages[0]['opcode'])
        self.assertEqual('Ping!', messages[0]['message'])

    def test_channel_slot_violation_by_client(self):
        request, dispatcher, mux_handler = self._create_mux_handler(slots=1)

        self._put_add_channel_request(request, channel_id=2)
        self._put_flow_control(request, channel_id=2, replenished_quota=10)
        self._put_logical_frame(request, 2, 'Hello')

        # This request should be rejected.
        self._put_add_channel_request(request, channel_id=3)
        self._put_flow_control(request, channel_id=3, replenished_quota=5)
        self._put_logical_frame(request, 3, 'Hello')

        self._put_logical_frame(request, 1, 'Goodbye')
        self._put_logical_frame(request, 2, 'Goodbye')

        mux_handler.wait_until_done(timeout=2)

        self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
        self.assertFalse(3 in dispatcher.channel_events)
780
781
if __name__ == '__main__':
    # Run every TestCase defined in this module when executed directly.
    unittest.main()
784
785
786# vi:sts=4 sw=4 et
787