1#!/usr/bin/env python
2#
3# Copyright 2011, Google Inc.
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11# notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13# copyright notice, this list of conditions and the following disclaimer
14# in the documentation and/or other materials provided with the
15# distribution.
16#     * Neither the name of Google Inc. nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32
33"""Tests for handshake module."""
34
35
36import unittest
37
38import set_sys_path  # Update sys.path to locate mod_pywebsocket module.
39from mod_pywebsocket import common
40from mod_pywebsocket.handshake._base import AbortedByUserException
41from mod_pywebsocket.handshake._base import HandshakeException
42from mod_pywebsocket.handshake._base import VersionException
43from mod_pywebsocket.handshake.hybi import Handshaker
44
45import mock
46
47
class RequestDefinition(object):
    """Plain holder for the pieces of an opening handshake request.

    Instances are turned into mock requests by the test helpers.

    Attributes:
        method: HTTP method string such as 'GET'.
        uri: request target such as '/demo'.
        headers: dict of header name -> value. Tests mutate this dict in
            place to derive variant or malformed requests.
    """

    def __init__(self, method, uri, headers):
        self.method, self.uri, self.headers = method, uri, headers
57
58
def _create_good_request_def():
    """Return a new RequestDefinition describing a valid handshake request.

    A fresh headers dict is built on every call, because callers mutate it
    to construct bad or variant requests. The Sec-WebSocket-Key value is the
    base64 encoding of 'the sample nonce'.
    """
    headers = {
        'Host': 'server.example.com',
        'Upgrade': 'websocket',
        'Connection': 'Upgrade',
        'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
        'Sec-WebSocket-Version': '13',
        'Origin': 'http://example.com',
    }
    return RequestDefinition(method='GET', uri='/demo', headers=headers)
68
69
def _create_request(request_def):
    """Build a mock.MockRequest from *request_def*.

    The request is attached to a mock.MockConn created with an empty
    string; tests read what the handshaker wrote via
    request.connection.written_data().
    """
    return mock.MockRequest(
        connection=mock.MockConn(''),
        method=request_def.method,
        uri=request_def.uri,
        headers_in=request_def.headers)
77
78
def _create_handshaker(request):
    """Return a Handshaker for *request* using a default mock.MockDispatcher."""
    return Handshaker(request, mock.MockDispatcher())
82
83
class SubprotocolChoosingDispatcher(object):
    """Test dispatcher that decides which subprotocol gets accepted.

    In do_extra_handshake it stores into conn_context.ws_protocol either:
      - conn_context.ws_requested_protocols[index] when index >= 0, or
      - default_value when index is negative.
    """

    def __init__(self, index, default_value=None):
        self.index = index
        self.default_value = default_value

    def do_extra_handshake(self, conn_context):
        if self.index < 0:
            chosen = self.default_value
        else:
            # Indexing happens before assignment, so an out-of-range index
            # raises without touching ws_protocol.
            chosen = conn_context.ws_requested_protocols[self.index]
        conn_context.ws_protocol = chosen

    def transfer_data(self, conn_context):
        """Do nothing; handshake tests need no data transfer."""
103
104
class HandshakeAbortedException(Exception):
    """Raised by AbortingDispatcher to reject a request in tests."""
107
108
class AbortingDispatcher(object):
    """Test dispatcher whose do_extra_handshake always rejects the request.

    It raises HandshakeAbortedException. test_aborting_handshake checks that
    Handshaker.do_handshake lets this exception propagate.
    """

    def do_extra_handshake(self, conn_context):
        reason = 'An exception to reject the request'
        raise HandshakeAbortedException(reason)

    def transfer_data(self, conn_context):
        """Do nothing."""
119
120
class AbortedByUserDispatcher(object):
    """Test dispatcher that rejects the request with AbortedByUserException.

    test_abort_extra_handshake checks that Handshaker.do_handshake lets
    this exception propagate.
    """

    def do_extra_handshake(self, conn_context):
        reason = 'An AbortedByUserException to reject the request'
        raise AbortedByUserException(reason)

    def transfer_data(self, conn_context):
        """Do nothing."""
132
133
134_EXPECTED_RESPONSE = (
135    'HTTP/1.1 101 Switching Protocols\r\n'
136    'Upgrade: websocket\r\n'
137    'Connection: Upgrade\r\n'
138    'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n')
139
140
class HandshakerTest(unittest.TestCase):
    """A unittest for draft-ietf-hybi-thewebsocketprotocol-06 and later
    handshake processor.
    """

    def test_do_handshake(self):
        request = _create_request(_create_good_request_def())
        dispatcher = mock.MockDispatcher()
        handshaker = Handshaker(request, dispatcher)
        handshaker.do_handshake()

        self.assertTrue(dispatcher.do_extra_handshake_called)

        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('/demo', request.ws_resource)
        self.assertEqual('http://example.com', request.ws_origin)
        self.assertEqual(None, request.ws_protocol)
        self.assertEqual(None, request.ws_extensions)
        self.assertEqual(common.VERSION_HYBI_LATEST, request.ws_version)

    def test_do_handshake_with_capitalized_value(self):
        # Upgrade and Connection values must be accepted regardless of case.
        request_def = _create_good_request_def()
        # Overwrite the existing 'Upgrade' entry. Assigning to a key with
        # different casing ('upgrade') would leave both entries in the dict,
        # so the uppercase value might not be the one the handshaker sees.
        request_def.headers['Upgrade'] = 'WEBSOCKET'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'UPGRADE'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

    def test_do_handshake_with_multiple_connection_values(self):
        # 'Upgrade' may appear among other tokens, including empty ones.
        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'Upgrade, keep-alive, , '

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

    def test_aborting_handshake(self):
        handshaker = Handshaker(
            _create_request(_create_good_request_def()),
            AbortingDispatcher())
        # do_extra_handshake raises an exception. Check that it's not caught by
        # do_handshake.
        self.assertRaises(HandshakeAbortedException, handshaker.do_handshake)

    def test_do_handshake_with_protocol(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = Handshaker(request, SubprotocolChoosingDispatcher(0))
        handshaker.do_handshake()

        EXPECTED_RESPONSE = (
            'HTTP/1.1 101 Switching Protocols\r\n'
            'Upgrade: websocket\r\n'
            'Connection: Upgrade\r\n'
            'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            'Sec-WebSocket-Protocol: chat\r\n\r\n')

        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('chat', request.ws_protocol)

    def test_do_handshake_protocol_not_in_request_but_in_response(self):
        request_def = _create_good_request_def()
        request = _create_request(request_def)
        handshaker = Handshaker(
            request, SubprotocolChoosingDispatcher(-1, 'foobar'))
        # No request has been made but ws_protocol is set. HandshakeException
        # must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_protocol_no_protocol_selection(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        # ws_protocol is not set. HandshakeException must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_extensions(self):
        # Only the recognized extension is echoed back; 'unknown' is dropped.
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'deflate-stream, unknown')

        EXPECTED_RESPONSE = (
            'HTTP/1.1 101 Switching Protocols\r\n'
            'Upgrade: websocket\r\n'
            'Connection: Upgrade\r\n'
            'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            'Sec-WebSocket-Extensions: deflate-stream\r\n\r\n')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual(1, len(request.ws_extensions))
        extension = request.ws_extensions[0]
        self.assertEqual('deflate-stream', extension.name())
        self.assertEqual(0, len(extension.get_parameter_names()))

    def test_do_handshake_with_quoted_extensions(self):
        # Empty list elements are skipped. Parameter values may be quoted
        # strings containing whitespace and backslash escapes.
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'deflate-stream, , '
            'unknown; e   =    "mc^2"; ma="\r\n      \\\rf  "; pv=nrt')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(2, len(request.ws_requested_extensions))
        first_extension = request.ws_requested_extensions[0]
        self.assertEqual('deflate-stream', first_extension.name())
        self.assertEqual(0, len(first_extension.get_parameter_names()))
        second_extension = request.ws_requested_extensions[1]
        self.assertEqual('unknown', second_extension.name())
        self.assertEqual(
            ['e', 'ma', 'pv'], second_extension.get_parameter_names())
        self.assertEqual('mc^2', second_extension.get_parameter_value('e'))
        self.assertEqual(' \rf ', second_extension.get_parameter_value('ma'))
        self.assertEqual('nrt', second_extension.get_parameter_value('pv'))

    def test_do_handshake_with_optional_headers(self):
        # Extra headers do not break the handshake and stay readable.
        request_def = _create_good_request_def()
        request_def.headers['EmptyValue'] = ''
        request_def.headers['AKey'] = 'AValue'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            'AValue', request.headers_in['AKey'])
        self.assertEqual(
            '', request.headers_in['EmptyValue'])

    def test_abort_extra_handshake(self):
        handshaker = Handshaker(
            _create_request(_create_good_request_def()),
            AbortedByUserDispatcher())
        # do_extra_handshake raises an AbortedByUserException. Check that it's
        # not caught by do_handshake.
        self.assertRaises(AbortedByUserException, handshaker.do_handshake)

    def test_do_handshake_with_mux_and_deflateframe(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = ('%s, %s' % (
                common.MUX_EXTENSION,
                common.DEFLATE_FRAME_EXTENSION))
        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(2, len(request.ws_extensions))
        self.assertEqual(common.MUX_EXTENSION,
                         request.ws_extensions[0].name())
        self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
                         request.ws_extensions[1].name())
        self.assertTrue(request.mux)
        self.assertEqual(0, len(request.mux_extensions))

    def test_do_handshake_with_deflateframe_and_mux(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = ('%s, %s' % (
                common.DEFLATE_FRAME_EXTENSION,
                common.MUX_EXTENSION))
        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        # mux should be rejected.
        self.assertEqual(1, len(request.ws_extensions))
        first_extension = request.ws_extensions[0]
        self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
                         first_extension.name())

    def test_bad_requests(self):
        # Each case is (name, request definition, expected status,
        # expect_handshake_exception). When expect_handshake_exception is
        # False, a VersionException is expected instead and the status is
        # not checked.
        bad_cases = [
            ('HTTP request',
             RequestDefinition(
                 'GET', '/demo',
                 {'Host': 'www.google.com',
                  'User-Agent':
                      'Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.5;'
                      ' en-US; rv:1.9.1.3) Gecko/20090824 Firefox/3.5.3'
                      ' GTB6 GTBA',
                  'Accept':
                      'text/html,application/xhtml+xml,application/xml;q=0.9,'
                      '*/*;q=0.8',
                  'Accept-Language': 'en-us,en;q=0.5',
                  'Accept-Encoding': 'gzip,deflate',
                  'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
                  'Keep-Alive': '300',
                  'Connection': 'keep-alive'}), None, True)]

        request_def = _create_good_request_def()
        request_def.method = 'POST'
        bad_cases.append(('Wrong method', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Host']
        bad_cases.append(('Missing Host', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Upgrade']
        bad_cases.append(('Missing Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Upgrade'] = 'nonwebsocket'
        bad_cases.append(('Wrong Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Connection']
        bad_cases.append(('Missing Connection', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'Downgrade'
        bad_cases.append(('Wrong Connection', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Key']
        bad_cases.append(('Missing Sec-WebSocket-Key', request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==garbage')
        bad_cases.append(('Wrong Sec-WebSocket-Key (with garbage on the tail)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = 'YQ=='  # BASE64 of 'a'
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (decoded value is not 16 octets long)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        # The last character right before == must be any of A, Q, w and g.
        request_def.headers['Sec-WebSocket-Key'] = (
            'AQIDBAUGBwgJCgsMDQ4PEC==')
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (padding bits are not zero)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==,dGhlIHNhbXBsZSBub25jZQ==')
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (multiple values)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Version']
        bad_cases.append(('Missing Sec-WebSocket-Version', request_def, None,
                          True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '3'
        bad_cases.append(('Wrong Sec-WebSocket-Version', request_def, None,
                          False))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '13, 13'
        bad_cases.append(('Wrong Sec-WebSocket-Version (multiple values)',
                          request_def, 400, True))

        for (case_name, request_def, expected_status,
             expect_handshake_exception) in bad_cases:
            request = _create_request(request_def)
            handshaker = Handshaker(request, mock.MockDispatcher())
            try:
                handshaker.do_handshake()
            except HandshakeException as e:
                # Name the failing case; otherwise a failure inside this loop
                # can't be attributed to a specific entry of bad_cases.
                self.assertTrue(
                    expect_handshake_exception,
                    'Unexpected HandshakeException for \'%s\' case'
                    % case_name)
                self.assertEqual(
                    expected_status, e.status,
                    'Wrong status for \'%s\' case: expected %r, got %r'
                    % (case_name, expected_status, e.status))
            except VersionException:
                self.assertFalse(
                    expect_handshake_exception,
                    'Unexpected VersionException for \'%s\' case'
                    % case_name)
            else:
                self.fail('No exception thrown for \'%s\' case' % case_name)
429
430
if __name__ == '__main__':
    # Discover and run every TestCase in this module when run as a script.
    unittest.main(module='__main__')
433
434
435# vi:sts=4 sw=4 et
436