1from __future__ import nested_scopes # Backward compat for 2.1 2from unittest import TestCase 3from wsgiref.util import setup_testing_defaults 4from wsgiref.headers import Headers 5from wsgiref.handlers import BaseHandler, BaseCGIHandler 6from wsgiref import util 7from wsgiref.validate import validator 8from wsgiref.simple_server import WSGIServer, WSGIRequestHandler, demo_app 9from wsgiref.simple_server import make_server 10from StringIO import StringIO 11from SocketServer import BaseServer 12import os 13import re 14import sys 15 16from test import test_support 17 18class MockServer(WSGIServer): 19 """Non-socket HTTP server""" 20 21 def __init__(self, server_address, RequestHandlerClass): 22 BaseServer.__init__(self, server_address, RequestHandlerClass) 23 self.server_bind() 24 25 def server_bind(self): 26 host, port = self.server_address 27 self.server_name = host 28 self.server_port = port 29 self.setup_environ() 30 31 32class MockHandler(WSGIRequestHandler): 33 """Non-socket HTTP handler""" 34 def setup(self): 35 self.connection = self.request 36 self.rfile, self.wfile = self.connection 37 38 def finish(self): 39 pass 40 41 42def hello_app(environ,start_response): 43 start_response("200 OK", [ 44 ('Content-Type','text/plain'), 45 ('Date','Mon, 05 Jun 2006 18:49:54 GMT') 46 ]) 47 return ["Hello, world!"] 48 49def run_amock(app=hello_app, data="GET / HTTP/1.0\n\n"): 50 server = make_server("", 80, app, MockServer, MockHandler) 51 inp, out, err, olderr = StringIO(data), StringIO(), StringIO(), sys.stderr 52 sys.stderr = err 53 54 try: 55 server.finish_request((inp,out), ("127.0.0.1",8888)) 56 finally: 57 sys.stderr = olderr 58 59 return out.getvalue(), err.getvalue() 60 61 62def compare_generic_iter(make_it,match): 63 """Utility to compare a generic 2.1/2.2+ iterator with an iterable 64 65 If running under Python 2.2+, this tests the iterator using iter()/next(), 66 as well as __getitem__. 'make_it' must be a function returning a fresh 67 iterator to be tested (since this may test the iterator twice).""" 68 69 it = make_it() 70 n = 0 71 for item in match: 72 if not it[n]==item: raise AssertionError 73 n+=1 74 try: 75 it[n] 76 except IndexError: 77 pass 78 else: 79 raise AssertionError("Too many items from __getitem__",it) 80 81 try: 82 iter, StopIteration 83 except NameError: 84 pass 85 else: 86 # Only test iter mode under 2.2+ 87 it = make_it() 88 if not iter(it) is it: raise AssertionError 89 for item in match: 90 if not it.next()==item: raise AssertionError 91 try: 92 it.next() 93 except StopIteration: 94 pass 95 else: 96 raise AssertionError("Too many items from .next()",it) 97 98 99class IntegrationTests(TestCase): 100 101 def check_hello(self, out, has_length=True): 102 self.assertEqual(out, 103 "HTTP/1.0 200 OK\r\n" 104 "Server: WSGIServer/0.1 Python/"+sys.version.split()[0]+"\r\n" 105 "Content-Type: text/plain\r\n" 106 "Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n" + 107 (has_length and "Content-Length: 13\r\n" or "") + 108 "\r\n" 109 "Hello, world!" 110 ) 111 112 def test_plain_hello(self): 113 out, err = run_amock() 114 self.check_hello(out) 115 116 def test_validated_hello(self): 117 out, err = run_amock(validator(hello_app)) 118 # the middleware doesn't support len(), so content-length isn't there 119 self.check_hello(out, has_length=False) 120 121 def test_simple_validation_error(self): 122 def bad_app(environ,start_response): 123 start_response("200 OK", ('Content-Type','text/plain')) 124 return ["Hello, world!"] 125 out, err = run_amock(validator(bad_app)) 126 self.assertTrue(out.endswith( 127 "A server error occurred. Please contact the administrator." 128 )) 129 self.assertEqual( 130 err.splitlines()[-2], 131 "AssertionError: Headers (('Content-Type', 'text/plain')) must" 132 " be of type list: <type 'tuple'>" 133 ) 134 135 136class UtilityTests(TestCase): 137 138 def checkShift(self,sn_in,pi_in,part,sn_out,pi_out): 139 env = {'SCRIPT_NAME':sn_in,'PATH_INFO':pi_in} 140 util.setup_testing_defaults(env) 141 self.assertEqual(util.shift_path_info(env),part) 142 self.assertEqual(env['PATH_INFO'],pi_out) 143 self.assertEqual(env['SCRIPT_NAME'],sn_out) 144 return env 145 146 def checkDefault(self, key, value, alt=None): 147 # Check defaulting when empty 148 env = {} 149 util.setup_testing_defaults(env) 150 if isinstance(value, StringIO): 151 self.assertIsInstance(env[key], StringIO) 152 else: 153 self.assertEqual(env[key], value) 154 155 # Check existing value 156 env = {key:alt} 157 util.setup_testing_defaults(env) 158 self.assertTrue(env[key] is alt) 159 160 def checkCrossDefault(self,key,value,**kw): 161 util.setup_testing_defaults(kw) 162 self.assertEqual(kw[key],value) 163 164 def checkAppURI(self,uri,**kw): 165 util.setup_testing_defaults(kw) 166 self.assertEqual(util.application_uri(kw),uri) 167 168 def checkReqURI(self,uri,query=1,**kw): 169 util.setup_testing_defaults(kw) 170 self.assertEqual(util.request_uri(kw,query),uri) 171 172 def checkFW(self,text,size,match): 173 174 def make_it(text=text,size=size): 175 return util.FileWrapper(StringIO(text),size) 176 177 compare_generic_iter(make_it,match) 178 179 it = make_it() 180 self.assertFalse(it.filelike.closed) 181 182 for item in it: 183 pass 184 185 self.assertFalse(it.filelike.closed) 186 187 it.close() 188 self.assertTrue(it.filelike.closed) 189 190 def testSimpleShifts(self): 191 self.checkShift('','/', '', '/', '') 192 self.checkShift('','/x', 'x', '/x', '') 193 self.checkShift('/','', None, '/', '') 194 self.checkShift('/a','/x/y', 'x', '/a/x', '/y') 195 self.checkShift('/a','/x/', 'x', '/a/x', '/') 196 197 def testNormalizedShifts(self): 198 self.checkShift('/a/b', '/../y', '..', '/a', '/y') 199 self.checkShift('', '/../y', '..', '', '/y') 200 self.checkShift('/a/b', '//y', 'y', '/a/b/y', '') 201 self.checkShift('/a/b', '//y/', 'y', '/a/b/y', '/') 202 self.checkShift('/a/b', '/./y', 'y', '/a/b/y', '') 203 self.checkShift('/a/b', '/./y/', 'y', '/a/b/y', '/') 204 self.checkShift('/a/b', '///./..//y/.//', '..', '/a', '/y/') 205 self.checkShift('/a/b', '///', '', '/a/b/', '') 206 self.checkShift('/a/b', '/.//', '', '/a/b/', '') 207 self.checkShift('/a/b', '/x//', 'x', '/a/b/x', '/') 208 self.checkShift('/a/b', '/.', None, '/a/b', '') 209 210 def testDefaults(self): 211 for key, value in [ 212 ('SERVER_NAME','127.0.0.1'), 213 ('SERVER_PORT', '80'), 214 ('SERVER_PROTOCOL','HTTP/1.0'), 215 ('HTTP_HOST','127.0.0.1'), 216 ('REQUEST_METHOD','GET'), 217 ('SCRIPT_NAME',''), 218 ('PATH_INFO','/'), 219 ('wsgi.version', (1,0)), 220 ('wsgi.run_once', 0), 221 ('wsgi.multithread', 0), 222 ('wsgi.multiprocess', 0), 223 ('wsgi.input', StringIO("")), 224 ('wsgi.errors', StringIO()), 225 ('wsgi.url_scheme','http'), 226 ]: 227 self.checkDefault(key,value) 228 229 def testCrossDefaults(self): 230 self.checkCrossDefault('HTTP_HOST',"foo.bar",SERVER_NAME="foo.bar") 231 self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="on") 232 self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="1") 233 self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="yes") 234 self.checkCrossDefault('wsgi.url_scheme',"http",HTTPS="foo") 235 self.checkCrossDefault('SERVER_PORT',"80",HTTPS="foo") 236 self.checkCrossDefault('SERVER_PORT',"443",HTTPS="on") 237 238 def testGuessScheme(self): 239 self.assertEqual(util.guess_scheme({}), "http") 240 self.assertEqual(util.guess_scheme({'HTTPS':"foo"}), "http") 241 self.assertEqual(util.guess_scheme({'HTTPS':"on"}), "https") 242 self.assertEqual(util.guess_scheme({'HTTPS':"yes"}), "https") 243 self.assertEqual(util.guess_scheme({'HTTPS':"1"}), "https") 244 245 def testAppURIs(self): 246 self.checkAppURI("http://127.0.0.1/") 247 self.checkAppURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam") 248 self.checkAppURI("http://spam.example.com:2071/", 249 HTTP_HOST="spam.example.com:2071", SERVER_PORT="2071") 250 self.checkAppURI("http://spam.example.com/", 251 SERVER_NAME="spam.example.com") 252 self.checkAppURI("http://127.0.0.1/", 253 HTTP_HOST="127.0.0.1", SERVER_NAME="spam.example.com") 254 self.checkAppURI("https://127.0.0.1/", HTTPS="on") 255 self.checkAppURI("http://127.0.0.1:8000/", SERVER_PORT="8000", 256 HTTP_HOST=None) 257 258 def testReqURIs(self): 259 self.checkReqURI("http://127.0.0.1/") 260 self.checkReqURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam") 261 self.checkReqURI("http://127.0.0.1/spammity/spam", 262 SCRIPT_NAME="/spammity", PATH_INFO="/spam") 263 self.checkReqURI("http://127.0.0.1/spammity/spam;ham", 264 SCRIPT_NAME="/spammity", PATH_INFO="/spam;ham") 265 self.checkReqURI("http://127.0.0.1/spammity/spam;cookie=1234,5678", 266 SCRIPT_NAME="/spammity", PATH_INFO="/spam;cookie=1234,5678") 267 self.checkReqURI("http://127.0.0.1/spammity/spam?say=ni", 268 SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni") 269 self.checkReqURI("http://127.0.0.1/spammity/spam", 0, 270 SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni") 271 272 def testFileWrapper(self): 273 self.checkFW("xyz"*50, 120, ["xyz"*40,"xyz"*10]) 274 275 def testHopByHop(self): 276 for hop in ( 277 "Connection Keep-Alive Proxy-Authenticate Proxy-Authorization " 278 "TE Trailers Transfer-Encoding Upgrade" 279 ).split(): 280 for alt in hop, hop.title(), hop.upper(), hop.lower(): 281 self.assertTrue(util.is_hop_by_hop(alt)) 282 283 # Not comprehensive, just a few random header names 284 for hop in ( 285 "Accept Cache-Control Date Pragma Trailer Via Warning" 286 ).split(): 287 for alt in hop, hop.title(), hop.upper(), hop.lower(): 288 self.assertFalse(util.is_hop_by_hop(alt)) 289 290class HeaderTests(TestCase): 291 292 def testMappingInterface(self): 293 test = [('x','y')] 294 self.assertEqual(len(Headers([])),0) 295 self.assertEqual(len(Headers(test[:])),1) 296 self.assertEqual(Headers(test[:]).keys(), ['x']) 297 self.assertEqual(Headers(test[:]).values(), ['y']) 298 self.assertEqual(Headers(test[:]).items(), test) 299 self.assertFalse(Headers(test).items() is test) # must be copy! 300 301 h=Headers([]) 302 del h['foo'] # should not raise an error 303 304 h['Foo'] = 'bar' 305 for m in h.has_key, h.__contains__, h.get, h.get_all, h.__getitem__: 306 self.assertTrue(m('foo')) 307 self.assertTrue(m('Foo')) 308 self.assertTrue(m('FOO')) 309 self.assertFalse(m('bar')) 310 311 self.assertEqual(h['foo'],'bar') 312 h['foo'] = 'baz' 313 self.assertEqual(h['FOO'],'baz') 314 self.assertEqual(h.get_all('foo'),['baz']) 315 316 self.assertEqual(h.get("foo","whee"), "baz") 317 self.assertEqual(h.get("zoo","whee"), "whee") 318 self.assertEqual(h.setdefault("foo","whee"), "baz") 319 self.assertEqual(h.setdefault("zoo","whee"), "whee") 320 self.assertEqual(h["foo"],"baz") 321 self.assertEqual(h["zoo"],"whee") 322 323 def testRequireList(self): 324 self.assertRaises(TypeError, Headers, "foo") 325 326 327 def testExtras(self): 328 h = Headers([]) 329 self.assertEqual(str(h),'\r\n') 330 331 h.add_header('foo','bar',baz="spam") 332 self.assertEqual(h['foo'], 'bar; baz="spam"') 333 self.assertEqual(str(h),'foo: bar; baz="spam"\r\n\r\n') 334 335 h.add_header('Foo','bar',cheese=None) 336 self.assertEqual(h.get_all('foo'), 337 ['bar; baz="spam"', 'bar; cheese']) 338 339 self.assertEqual(str(h), 340 'foo: bar; baz="spam"\r\n' 341 'Foo: bar; cheese\r\n' 342 '\r\n' 343 ) 344 345 346class ErrorHandler(BaseCGIHandler): 347 """Simple handler subclass for testing BaseHandler""" 348 349 # BaseHandler records the OS environment at import time, but envvars 350 # might have been changed later by other tests, which trips up 351 # HandlerTests.testEnviron(). 352 os_environ = dict(os.environ.items()) 353 354 def __init__(self,**kw): 355 setup_testing_defaults(kw) 356 BaseCGIHandler.__init__( 357 self, StringIO(''), StringIO(), StringIO(), kw, 358 multithread=True, multiprocess=True 359 ) 360 361class TestHandler(ErrorHandler): 362 """Simple handler subclass for testing BaseHandler, w/error passthru""" 363 364 def handle_error(self): 365 raise # for testing, we want to see what's happening 366 367 368class HandlerTests(TestCase): 369 370 def checkEnvironAttrs(self, handler): 371 env = handler.environ 372 for attr in [ 373 'version','multithread','multiprocess','run_once','file_wrapper' 374 ]: 375 if attr=='file_wrapper' and handler.wsgi_file_wrapper is None: 376 continue 377 self.assertEqual(getattr(handler,'wsgi_'+attr),env['wsgi.'+attr]) 378 379 def checkOSEnviron(self,handler): 380 empty = {}; setup_testing_defaults(empty) 381 env = handler.environ 382 from os import environ 383 for k,v in environ.items(): 384 if k not in empty: 385 self.assertEqual(env[k],v) 386 for k,v in empty.items(): 387 self.assertIn(k, env) 388 389 def testEnviron(self): 390 h = TestHandler(X="Y") 391 h.setup_environ() 392 self.checkEnvironAttrs(h) 393 self.checkOSEnviron(h) 394 self.assertEqual(h.environ["X"],"Y") 395 396 def testCGIEnviron(self): 397 h = BaseCGIHandler(None,None,None,{}) 398 h.setup_environ() 399 for key in 'wsgi.url_scheme', 'wsgi.input', 'wsgi.errors': 400 self.assertIn(key, h.environ) 401 402 def testScheme(self): 403 h=TestHandler(HTTPS="on"); h.setup_environ() 404 self.assertEqual(h.environ['wsgi.url_scheme'],'https') 405 h=TestHandler(); h.setup_environ() 406 self.assertEqual(h.environ['wsgi.url_scheme'],'http') 407 408 def testAbstractMethods(self): 409 h = BaseHandler() 410 for name in [ 411 '_flush','get_stdin','get_stderr','add_cgi_vars' 412 ]: 413 self.assertRaises(NotImplementedError, getattr(h,name)) 414 self.assertRaises(NotImplementedError, h._write, "test") 415 416 def testContentLength(self): 417 # Demo one reason iteration is better than write()... ;) 418 419 def trivial_app1(e,s): 420 s('200 OK',[]) 421 return [e['wsgi.url_scheme']] 422 423 def trivial_app2(e,s): 424 s('200 OK',[])(e['wsgi.url_scheme']) 425 return [] 426 427 def trivial_app4(e,s): 428 # Simulate a response to a HEAD request 429 s('200 OK',[('Content-Length', '12345')]) 430 return [] 431 432 h = TestHandler() 433 h.run(trivial_app1) 434 self.assertEqual(h.stdout.getvalue(), 435 "Status: 200 OK\r\n" 436 "Content-Length: 4\r\n" 437 "\r\n" 438 "http") 439 440 h = TestHandler() 441 h.run(trivial_app2) 442 self.assertEqual(h.stdout.getvalue(), 443 "Status: 200 OK\r\n" 444 "\r\n" 445 "http") 446 447 448 h = TestHandler() 449 h.run(trivial_app4) 450 self.assertEqual(h.stdout.getvalue(), 451 b'Status: 200 OK\r\n' 452 b'Content-Length: 12345\r\n' 453 b'\r\n') 454 455 def testBasicErrorOutput(self): 456 457 def non_error_app(e,s): 458 s('200 OK',[]) 459 return [] 460 461 def error_app(e,s): 462 raise AssertionError("This should be caught by handler") 463 464 h = ErrorHandler() 465 h.run(non_error_app) 466 self.assertEqual(h.stdout.getvalue(), 467 "Status: 200 OK\r\n" 468 "Content-Length: 0\r\n" 469 "\r\n") 470 self.assertEqual(h.stderr.getvalue(),"") 471 472 h = ErrorHandler() 473 h.run(error_app) 474 self.assertEqual(h.stdout.getvalue(), 475 "Status: %s\r\n" 476 "Content-Type: text/plain\r\n" 477 "Content-Length: %d\r\n" 478 "\r\n%s" % (h.error_status,len(h.error_body),h.error_body)) 479 480 self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1) 481 482 def testErrorAfterOutput(self): 483 MSG = "Some output has been sent" 484 def error_app(e,s): 485 s("200 OK",[])(MSG) 486 raise AssertionError("This should be caught by handler") 487 488 h = ErrorHandler() 489 h.run(error_app) 490 self.assertEqual(h.stdout.getvalue(), 491 "Status: 200 OK\r\n" 492 "\r\n"+MSG) 493 self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1) 494 495 def testHeaderFormats(self): 496 497 def non_error_app(e,s): 498 s('200 OK',[]) 499 return [] 500 501 stdpat = ( 502 r"HTTP/%s 200 OK\r\n" 503 r"Date: \w{3}, [ 0123]\d \w{3} \d{4} \d\d:\d\d:\d\d GMT\r\n" 504 r"%s" r"Content-Length: 0\r\n" r"\r\n" 505 ) 506 shortpat = ( 507 "Status: 200 OK\r\n" "Content-Length: 0\r\n" "\r\n" 508 ) 509 510 for ssw in "FooBar/1.0", None: 511 sw = ssw and "Server: %s\r\n" % ssw or "" 512 513 for version in "1.0", "1.1": 514 for proto in "HTTP/0.9", "HTTP/1.0", "HTTP/1.1": 515 516 h = TestHandler(SERVER_PROTOCOL=proto) 517 h.origin_server = False 518 h.http_version = version 519 h.server_software = ssw 520 h.run(non_error_app) 521 self.assertEqual(shortpat,h.stdout.getvalue()) 522 523 h = TestHandler(SERVER_PROTOCOL=proto) 524 h.origin_server = True 525 h.http_version = version 526 h.server_software = ssw 527 h.run(non_error_app) 528 if proto=="HTTP/0.9": 529 self.assertEqual(h.stdout.getvalue(),"") 530 else: 531 self.assertTrue( 532 re.match(stdpat%(version,sw), h.stdout.getvalue()), 533 (stdpat%(version,sw), h.stdout.getvalue()) 534 ) 535 536 def testCloseOnError(self): 537 side_effects = {'close_called': False} 538 MSG = b"Some output has been sent" 539 def error_app(e,s): 540 s("200 OK",[])(MSG) 541 class CrashyIterable(object): 542 def __iter__(self): 543 while True: 544 yield b'blah' 545 raise AssertionError("This should be caught by handler") 546 547 def close(self): 548 side_effects['close_called'] = True 549 return CrashyIterable() 550 551 h = ErrorHandler() 552 h.run(error_app) 553 self.assertEqual(side_effects['close_called'], True) 554 555 556def test_main(): 557 test_support.run_unittest(__name__) 558 559if __name__ == "__main__": 560 test_main() 561