#!/usr/bin/env python
#
# Copyright 2011, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


"""Tests for handshake module."""


import unittest

import set_sys_path  # Update sys.path to locate mod_pywebsocket module.
from mod_pywebsocket import common
from mod_pywebsocket.handshake._base import AbortedByUserException
from mod_pywebsocket.handshake._base import HandshakeException
from mod_pywebsocket.handshake._base import VersionException
from mod_pywebsocket.handshake.hybi import Handshaker

import mock


class RequestDefinition(object):
    """A class for holding data for constructing opening handshake strings for
    testing the opening handshake processor.
    """

    def __init__(self, method, uri, headers):
        # HTTP method of the request line, e.g. 'GET'.
        self.method = method
        # Request-URI of the request line, e.g. '/demo'.
        self.uri = uri
        # Dict mapping header names to header values.
        self.headers = headers


def _create_good_request_def():
    """Returns a RequestDefinition describing a valid opening handshake.

    A new RequestDefinition (with a new headers dict) is built on every
    call, so tests may mutate the result without affecting other tests.
    The Sec-WebSocket-Key is the base64 encoding of 'the sample nonce'; the
    matching Sec-WebSocket-Accept value is the one in _EXPECTED_RESPONSE.
    """
    return RequestDefinition(
        'GET', '/demo',
        {'Host': 'server.example.com',
         'Upgrade': 'websocket',
         'Connection': 'Upgrade',
         'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
         'Sec-WebSocket-Version': '13',
         'Origin': 'http://example.com'})


def _create_request(request_def):
    """Builds a mock.MockRequest from request_def.

    The request's connection is a mock.MockConn constructed with ''; tests
    inspect what the handshaker sent via request.connection.written_data().
    """
    conn = mock.MockConn('')
    return mock.MockRequest(
        method=request_def.method,
        uri=request_def.uri,
        headers_in=request_def.headers,
        connection=conn)


def _create_handshaker(request):
    """Returns a Handshaker for request that uses a mock.MockDispatcher."""
    return Handshaker(request, mock.MockDispatcher())


class SubprotocolChoosingDispatcher(object):
    """A dispatcher for testing. This dispatcher sets the i-th subprotocol
    of requested ones to ws_protocol where i is given on construction as index
    argument. If index is negative, default_value will be set to ws_protocol.
    """

    def __init__(self, index, default_value=None):
        self.index = index
        self.default_value = default_value

    def do_extra_handshake(self, conn_context):
        if self.index >= 0:
            conn_context.ws_protocol = conn_context.ws_requested_protocols[
                self.index]
        else:
            conn_context.ws_protocol = self.default_value

    def transfer_data(self, conn_context):
        pass


class HandshakeAbortedException(Exception):
    """Raised by AbortingDispatcher to reject a handshake."""
    pass


class AbortingDispatcher(object):
    """A dispatcher for testing. This dispatcher raises an exception in
    do_extra_handshake to reject the request.
    """

    def do_extra_handshake(self, conn_context):
        raise HandshakeAbortedException('An exception to reject the request')

    def transfer_data(self, conn_context):
        pass


class AbortedByUserDispatcher(object):
    """A dispatcher for testing. This dispatcher raises an
    AbortedByUserException in do_extra_handshake to reject the request.
    """

    def do_extra_handshake(self, conn_context):
        raise AbortedByUserException('An AbortedByUserException to reject the '
                                     'request')

    def transfer_data(self, conn_context):
        pass


# Response expected for the request built by _create_good_request_def() when
# no subprotocol or extension is negotiated.
_EXPECTED_RESPONSE = (
    'HTTP/1.1 101 Switching Protocols\r\n'
    'Upgrade: websocket\r\n'
    'Connection: Upgrade\r\n'
    'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n')


class HandshakerTest(unittest.TestCase):
    """A unittest for draft-ietf-hybi-thewebsocketprotocol-06 and later
    handshake processor.
    """

    def test_do_handshake(self):
        request = _create_request(_create_good_request_def())
        dispatcher = mock.MockDispatcher()
        handshaker = Handshaker(request, dispatcher)
        handshaker.do_handshake()

        self.assertTrue(dispatcher.do_extra_handshake_called)

        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('/demo', request.ws_resource)
        self.assertEqual('http://example.com', request.ws_origin)
        self.assertEqual(None, request.ws_protocol)
        self.assertEqual(None, request.ws_extensions)
        self.assertEqual(common.VERSION_HYBI_LATEST, request.ws_version)

    def test_do_handshake_with_capitalized_value(self):
        # The Upgrade and Connection values must be accepted regardless of
        # case. Overwrite the existing 'Upgrade' entry rather than adding a
        # second, differently-cased key: with two entries for the same
        # header, which value reaches the handshaker depends on dict
        # iteration order, and the capitalized value might never be tested.
        request_def = _create_good_request_def()
        request_def.headers['Upgrade'] = 'WEBSOCKET'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'UPGRADE'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

    def test_do_handshake_with_multiple_connection_values(self):
        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'Upgrade, keep-alive, , '

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            _EXPECTED_RESPONSE, request.connection.written_data())

    def test_aborting_handshake(self):
        handshaker = Handshaker(
            _create_request(_create_good_request_def()),
            AbortingDispatcher())
        # do_extra_handshake raises an exception. Check that it's not caught by
        # do_handshake.
        self.assertRaises(HandshakeAbortedException, handshaker.do_handshake)

    def test_do_handshake_with_protocol(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = Handshaker(request, SubprotocolChoosingDispatcher(0))
        handshaker.do_handshake()

        EXPECTED_RESPONSE = (
            'HTTP/1.1 101 Switching Protocols\r\n'
            'Upgrade: websocket\r\n'
            'Connection: Upgrade\r\n'
            'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            'Sec-WebSocket-Protocol: chat\r\n\r\n')

        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('chat', request.ws_protocol)

    def test_do_handshake_protocol_not_in_request_but_in_response(self):
        request_def = _create_good_request_def()
        request = _create_request(request_def)
        handshaker = Handshaker(
            request, SubprotocolChoosingDispatcher(-1, 'foobar'))
        # No request has been made but ws_protocol is set. HandshakeException
        # must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_protocol_no_protocol_selection(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        # ws_protocol is not set. HandshakeException must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_extensions(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'deflate-stream, unknown')

        EXPECTED_RESPONSE = (
            'HTTP/1.1 101 Switching Protocols\r\n'
            'Upgrade: websocket\r\n'
            'Connection: Upgrade\r\n'
            'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            'Sec-WebSocket-Extensions: deflate-stream\r\n\r\n')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual(1, len(request.ws_extensions))
        extension = request.ws_extensions[0]
        self.assertEqual('deflate-stream', extension.name())
        self.assertEqual(0, len(extension.get_parameter_names()))

    def test_do_handshake_with_quoted_extensions(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'deflate-stream, , '
            'unknown; e = "mc^2"; ma="\r\n \\\rf "; pv=nrt')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(2, len(request.ws_requested_extensions))
        first_extension = request.ws_requested_extensions[0]
        self.assertEqual('deflate-stream', first_extension.name())
        self.assertEqual(0, len(first_extension.get_parameter_names()))
        second_extension = request.ws_requested_extensions[1]
        self.assertEqual('unknown', second_extension.name())
        self.assertEqual(
            ['e', 'ma', 'pv'], second_extension.get_parameter_names())
        self.assertEqual('mc^2', second_extension.get_parameter_value('e'))
        self.assertEqual(' \rf ', second_extension.get_parameter_value('ma'))
        self.assertEqual('nrt', second_extension.get_parameter_value('pv'))

    def test_do_handshake_with_optional_headers(self):
        request_def = _create_good_request_def()
        request_def.headers['EmptyValue'] = ''
        request_def.headers['AKey'] = 'AValue'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(
            'AValue', request.headers_in['AKey'])
        self.assertEqual(
            '', request.headers_in['EmptyValue'])

    def test_abort_extra_handshake(self):
        handshaker = Handshaker(
            _create_request(_create_good_request_def()),
            AbortedByUserDispatcher())
        # do_extra_handshake raises an AbortedByUserException. Check that it's
        # not caught by do_handshake.
        self.assertRaises(AbortedByUserException, handshaker.do_handshake)

    def test_do_handshake_with_mux_and_deflateframe(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = ('%s, %s' % (
            common.MUX_EXTENSION,
            common.DEFLATE_FRAME_EXTENSION))
        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(2, len(request.ws_extensions))
        self.assertEqual(common.MUX_EXTENSION,
                         request.ws_extensions[0].name())
        self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
                         request.ws_extensions[1].name())
        self.assertTrue(request.mux)
        self.assertEqual(0, len(request.mux_extensions))

    def test_do_handshake_with_deflateframe_and_mux(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = ('%s, %s' % (
            common.DEFLATE_FRAME_EXTENSION,
            common.MUX_EXTENSION))
        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        # mux should be rejected.
        self.assertEqual(1, len(request.ws_extensions))
        first_extension = request.ws_extensions[0]
        self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
                         first_extension.name())

    def test_bad_requests(self):
        # Each case is a tuple of:
        #   (case name, request definition,
        #    expected HandshakeException.status,
        #    True if a HandshakeException is expected, False if a
        #    VersionException is expected)
        bad_cases = [
            ('HTTP request',
             RequestDefinition(
                 'GET', '/demo',
                 {'Host': 'www.google.com',
                  'User-Agent':
                      'Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.5;'
                      ' en-US; rv:1.9.1.3) Gecko/20090824 Firefox/3.5.3'
                      ' GTB6 GTBA',
                  'Accept':
                      'text/html,application/xhtml+xml,application/xml;q=0.9,'
                      '*/*;q=0.8',
                  'Accept-Language': 'en-us,en;q=0.5',
                  'Accept-Encoding': 'gzip,deflate',
                  'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
                  'Keep-Alive': '300',
                  'Connection': 'keep-alive'}), None, True)]

        request_def = _create_good_request_def()
        request_def.method = 'POST'
        bad_cases.append(('Wrong method', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Host']
        bad_cases.append(('Missing Host', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Upgrade']
        bad_cases.append(('Missing Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Upgrade'] = 'nonwebsocket'
        bad_cases.append(('Wrong Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Connection']
        bad_cases.append(('Missing Connection', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'Downgrade'
        bad_cases.append(('Wrong Connection', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Key']
        bad_cases.append(('Missing Sec-WebSocket-Key', request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==garbage')
        bad_cases.append(('Wrong Sec-WebSocket-Key (with garbage on the tail)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = 'YQ=='  # BASE64 of 'a'
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (decoded value is not 16 octets long)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        # The last character right before == must be any of A, Q, w and g.
        request_def.headers['Sec-WebSocket-Key'] = (
            'AQIDBAUGBwgJCgsMDQ4PEC==')
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (padding bits are not zero)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==,dGhlIHNhbXBsZSBub25jZQ==')
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (multiple values)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Version']
        bad_cases.append(('Missing Sec-WebSocket-Version', request_def, None,
                          True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '3'
        bad_cases.append(('Wrong Sec-WebSocket-Version', request_def, None,
                          False))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '13, 13'
        bad_cases.append(('Wrong Sec-WebSocket-Version (multiple values)',
                          request_def, 400, True))

        for (case_name, request_def, expected_status,
             expect_handshake_exception) in bad_cases:
            request = _create_request(request_def)
            handshaker = Handshaker(request, mock.MockDispatcher())
            # Keep only the call under test inside the try so that
            # assertion failures are never confused with handshake errors.
            try:
                handshaker.do_handshake()
            except HandshakeException as e:
                self.assertTrue(
                    expect_handshake_exception,
                    'Unexpected HandshakeException for \'%s\' case' %
                    case_name)
                self.assertEqual(
                    expected_status, e.status,
                    'Unexpected status for \'%s\' case' % case_name)
            except VersionException:
                self.assertFalse(
                    expect_handshake_exception,
                    'Unexpected VersionException for \'%s\' case' %
                    case_name)
            else:
                self.fail('No exception thrown for \'%s\' case' % case_name)


if __name__ == '__main__':
    unittest.main()


# vi:sts=4 sw=4 et