sock_diag_test.py revision 2c96358e57d689833af77ec0c13831a9ae1b27b8
#!/usr/bin/python
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from errno import *
import os
import random
from socket import *
import threading
import time
import unittest

import csocket
import cstruct
import multinetwork_base
import net_test
import packets
import sock_diag


NUM_SOCKETS = 100
NO_BYTECODE = ""

# TODO: Backport SOCK_DESTROY and delete this.
HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)


class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Base class providing socket-creation helpers and connection asserts."""

  @staticmethod
  def _CreateLotsOfSockets():
    """Creates NUM_SOCKETS connected TCP socket pairs over loopback.

    Each pair is created over either IPv4 (127.0.0.1) or IPv6 (::1), chosen at
    random.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport is
      the local port of the first socket and dport the local port of the
      second.
    """
    socketpairs = {}
    for _ in xrange(NUM_SOCKETS):
      family, addr = random.choice([(AF_INET, "127.0.0.1"), (AF_INET6, "::1")])
      socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that getpeername() on sock succeeds."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)


class SockDiagTest(SockDiagBaseTest):
  """Tests for inet_diag socket dumps, bytecode filters and SOCK_DESTROY."""

  def setUp(self):
    super(SockDiagTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    self.socketpairs = {}

  def tearDown(self):
    # Close every socket a test stored in self.socketpairs.
    for socketpair in self.socketpairs.values():
      for s in socketpair:
        s.close()
    super(SockDiagTest, self).tearDown()

  def testFixupDiagMsg(self):
    """Checks FixupDiagMsg on IPv4 and IPv6 messages with junk addresses.

    For an AF_INET message, FixupDiagMsg is expected to keep only the first 4
    bytes of src/dst and zero the remaining 12. An AF_INET6 message is
    expected to be left unchanged.
    """
    src = "0a00fa02303030312030312038302031"
    dst = "0808080841414141414141416f0a3230"
    cookie = "4078678100000000"
    sockid = sock_diag.InetDiagSockId((47436, 32069,
                                       src.decode("hex"), dst.decode("hex"), 0,
                                       cookie.decode("hex")))
    msg4 = sock_diag.InetDiagMsg((AF_INET, IPPROTO_TCP, 0,
                                  sock_diag.TCP_SYN_RECV, sockid,
                                  980, 123, 456, 789, 5555))
    # Make a copy, cstructs are mutable.
    msg6 = sock_diag.InetDiagMsg(msg4.Pack())
    msg6.family = AF_INET6

    fixed6 = sock_diag.InetDiagMsg(msg6.Pack())
    self.sock_diag.FixupDiagMsg(fixed6)
    self.assertEquals(msg6.Pack(), fixed6.Pack())

    fixed4 = sock_diag.InetDiagMsg(msg4.Pack())
    self.sock_diag.FixupDiagMsg(fixed4)
    msg4.id.src = src.decode("hex")[:4] + 12 * "\x00"
    msg4.id.dst = dst.decode("hex")[:4] + 12 * "\x00"
    self.assertEquals(msg4.Pack(), fixed4.Pack())

  def assertSockDiagMatchesSocket(self, s, diag_msg):
    """Asserts that diag_msg describes the socket s.

    Note: this calls FixupDiagMsg, which modifies diag_msg in place.

    Args:
      s: the socket object.
      diag_msg: an InetDiagMsg, e.g. as returned by FindSockDiagFromFd(s).
    """
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    self.sock_diag.FixupDiagMsg(diag_msg)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      # An unspecified destination means the socket must not be connected.
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

  def testFindsAllMySockets(self):
    """Checks that dumps contain our sockets and that cookie lookups work."""
    self.socketpairs = self._CreateLotsOfSockets()
    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. self.socketpairs is keyed by
    # the first socket's (addr, local port, peer port), so the second socket
    # of each pair shows up in the dump with its ports swapped.
    cookies = {}
    for diag_msg, _ in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEquals(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
        self.assertSockDiagMatchesSocket(sock, diag_msg)

        # Check that we can find a diag_msg once we know the cookie. The state
        # filter must come from this socket's own diag_msg, not from whichever
        # socket was examined previously.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = diag_msg.id.cookie
        req.states = 1 << diag_msg.state
        diag_msg = self.sock_diag.GetSockDiag(req)
        self.assertSockDiagMatchesSocket(sock, diag_msg)

  def testBytecodeCompilation(self):
    """Checks PackBytecode's encoding and that compiled filters select sockets."""
    # Trailing comments give each instruction's byte offset in the packed
    # bytecode (76 bytes in total, as asserted below).
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    bytecode = self.sock_diag.PackBytecode(instructions)
    expected = (
        "0208500000000000"
        "050848000000ffff"
        "071c20000a800000ffffffff00000000000000000000000000000001"
        "01041c00"
        "0718200002200000ffffffff7f000001"
        "0508100000006566"
        "00040400"
    )
    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    self.assertEquals(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets()
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.sock_diag.PackBytecode(instructions)
        self.assertEquals(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEquals(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, ""))

    # Keep references so that these sockets stay open for the dumps below.
    pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")

    bytecode4 = self.sock_diag.PackBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.sock_diag.PackBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4)
    self.assertTrue(v4sockets)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets))

    v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6)
    self.assertTrue(v6sockets)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                                  bytecode4)]
    v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                                  bytecode6)]
    self.assertTrue(all(d in v4sockets for d in diag_msgs))
    self.assertTrue(all(d in v6sockets for d in diag_msgs))

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testClosesSockets(self):
    """Checks that destroying one end of each pair closes both ends.

    Sockets are destroyed either by fd or by a request carrying the socket
    cookie; in the latter case a wrong cookie must fail with ENOENT and leave
    the socket connected.
    """
    self.socketpairs = self._CreateLotsOfSockets()
    for socketpair in self.socketpairs.values():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed.
        real_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(real_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testNonTcpSockets(self):
    """Checks that destroying a connected UDP socket fails with EOPNOTSUPP."""
    s = socket(AF_INET6, SOCK_DGRAM, 0)
    try:
      s.connect(("::1", 53))
      # Look the socket up first; only the destroy is expected to fail.
      self.sock_diag.FindSockDiagFromFd(s)
      self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
    finally:
      s.close()

  def testNonSockDiagCommand(self):
    """Checks that a dump with an unknown message type fails with EINVAL.

    SOCK_DIAG_BY_FAMILY must succeed; SOCK_DIAG_BY_FAMILY + 17 must not.
    """
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.


class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records what it raises.

  After the thread finishes, self.exception holds the exception raised by the
  operation, or None if it returned normally.
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def run(self):
    try:
      self.operation(self.sock)
    except Exception as e:  # pylint: disable=broad-except
      # Stash the exception so the main thread can inspect it after join().
      self.exception = e


# TODO: Take a tun fd as input, make this a utility class, and reuse at least
# in forwarding_test.
class TcpTest(SockDiagBaseTest):
  """Tests SOCK_DESTROY on TCP sockets whose peer is simulated with packets.

  The remote end of each connection is emulated by injecting packets on, and
  expecting packets from, one of the test networks (self.netid).
  """

  # Pseudo-state accepted by IncomingConnection: the handshake has completed
  # but accept() has not been called yet.
  NOT_YET_ACCEPTED = -1

  def setUp(self):
    super(TcpTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    self.netid = random.choice(list(self.tuns.keys()))

  def OpenListenSocket(self, version):
    """Returns a TCP socket listening on a random port of the wildcard address.

    The port is stored in self.port.

    Args:
      version: 4 for an AF_INET socket, 6 for an AF_INET6 socket.
    """
    self.port = packets.RandomPort()
    family = {4: AF_INET, 6: AF_INET6}[version]
    address = {4: "0.0.0.0", 6: "::"}[version]
    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    s.bind((address, self.port))
    # We haven't configured inbound iptables marking, so bind explicitly.
    self.SelectInterface(s, self.netid, "mark")
    s.listen(100)
    return s

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Like the base class method, but also records the result in last_packet."""
    pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
                                                        reply, msg)
    self.last_packet = pkt
    return pkt

  def ReceivePacketOn(self, netid, packet):
    """Like the base class method, but also records packet in last_packet."""
    super(TcpTest, self).ReceivePacketOn(netid, packet)
    self.last_packet = packet

  def RstPacket(self):
    """Returns (description, packet) for a RST from myaddr to remoteaddr.

    The RST is built from self.last_packet using the addresses and version
    stored by IncomingConnection.
    """
    return packets.RST(self.version, self.myaddr, self.remoteaddr,
                       self.last_packet)

  def IncomingConnection(self, version, end_state, netid):
    """Opens a listening socket and drives an incoming connection to it.

    Sets self.s to the listening socket and, for TCP_ESTABLISHED and
    TCP_CLOSE_WAIT, self.accepted to the accepted socket. Also sets
    self.version, self.end_state, self.remoteaddr and self.myaddr.

    Args:
      version: 4 or 6 for an IPv4 or IPv6 connection, or 5 for an IPv4
        connection to an AF_INET6 listener using IPv4-mapped addresses.
      end_state: where to stop: sock_diag.TCP_LISTEN, sock_diag.TCP_SYN_RECV,
        self.NOT_YET_ACCEPTED, sock_diag.TCP_ESTABLISHED or
        sock_diag.TCP_CLOSE_WAIT.
      netid: the network on which packets are injected and expected.

    Raises:
      ValueError: if end_state is not one of the above. This is only detected
        after the connection has been driven all the way to TCP_CLOSE_WAIT.
    """
    if version == 5:
      mapped = True
      socket_version = 6
      version = 4
    else:
      socket_version = version
      mapped = False

    self.version = version
    self.s = self.OpenListenSocket(socket_version)
    self.end_state = end_state

    def MaybeMappedAddress(addr):
      return "::ffff:%s" % addr if mapped else addr

    remoteaddr = self.remoteaddr = MaybeMappedAddress(
        self.GetRemoteAddress(version))
    myaddr = self.myaddr = MaybeMappedAddress(
        self.MyAddress(version, netid))

    if end_state == sock_diag.TCP_LISTEN:
      return

    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
    if end_state == sock_diag.TCP_SYN_RECV:
      return

    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    self.ReceivePacketOn(netid, establishing_ack)

    if end_state == self.NOT_YET_ACCEPTED:
      return

    self.accepted, _ = self.s.accept()
    net_test.DisableLinger(self.accepted)

    if end_state == sock_diag.TCP_ESTABLISHED:
      return

    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                             payload=net_test.UDP_PAYLOAD)
    self.accepted.send(net_test.UDP_PAYLOAD)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)

    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
    fin = packets._GetIpLayer(version)(str(fin))
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)

    # TODO: Why can't we use this?
    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
    self.ReceivePacketOn(netid, fin)
    time.sleep(0.1)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
    if end_state == sock_diag.TCP_CLOSE_WAIT:
      return

    raise ValueError("Invalid TCP state %d specified" % end_state)

  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    """Closes the socket and checks whether a RST is sent or not.

    Exactly one of sock and req must be specified.

    Args:
      sock: a socket object to destroy by fd, or None.
      req: a sock_diag request identifying the socket to destroy (e.g. from
        DiagReqFromDiagMsg), or None.
      expect_reset: whether a RST is expected on self.netid.
      msg: description used in failure messages.
      do_close: if sock is specified, whether to also close() it afterwards.
    """
    if sock is not None:
      self.assertIsNone(req, "Must specify sock or req, not both")
      self.sock_diag.CloseSocketFromFd(sock)
      self.assertRaisesErrno(EINVAL, sock.accept)
    else:
      # Previously this asserted that sock was None, which is always true
      # here; the meaningful check is that a request was supplied instead.
      self.assertIsNotNone(req, "Must specify sock or req")
      self.sock_diag.CloseSocket(req)

    if expect_reset:
      desc, rst = self.RstPacket()
      msg = "%s: expecting %s: " % (msg, desc)
      self.ExpectPacketOn(self.netid, msg, rst)
    else:
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)

    if sock is not None and do_close:
      sock.close()

  def CheckTcpReset(self, state, statename):
    """Checks which destroyed sockets send a RST for the given state.

    For IPv4 and IPv6, destroying the listener must not send a RST. Unless
    state is TCP_LISTEN, destroying the accepted socket must send one.
    """
    for version in [4, 6]:
      msg = "Closing incoming IPv%d %s socket" % (version, statename)
      self.IncomingConnection(version, state, self.netid)
      self.CheckRstOnClose(self.s, None, False, msg)
      if state != sock_diag.TCP_LISTEN:
        msg = "Closing accepted IPv%d %s socket" % (version, statename)
        self.CheckRstOnClose(self.accepted, None, True, msg)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testTcpResets(self):
    """Checks that closing sockets in appropriate states sends a RST."""
    self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
    self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")

  def FindChildSockets(self, s):
    """Finds the SYN_RECV and ESTABLISHED child sockets of listening socket s.

    Returns:
      A list of sock_diag requests, one per child found in the dump.
    """
    d = self.sock_diag.FindSockDiagFromFd(s)
    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
    # Zero the cookie so that the dump is not restricted to the listener.
    req.id.cookie = "\x00" * 8
    children = self.sock_diag.Dump(req, NO_BYTECODE)
    return [self.sock_diag.DiagReqFromDiagMsg(child, IPPROTO_TCP)
            for child, _ in children]

  def CheckChildSocket(self, state, statename, parent_first):
    """Checks destroying a listener and its single child, in either order.

    Args:
      state: the state passed to IncomingConnection, e.g.
        sock_diag.TCP_SYN_RECV or self.NOT_YET_ACCEPTED.
      statename: human-readable name of state, used in messages.
      parent_first: if True, destroy the listener before its child.
    """
    for version in [4, 6]:
      self.IncomingConnection(version, state, self.netid)

      d = self.sock_diag.FindSockDiagFromFd(self.s)
      parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
      children = self.FindChildSockets(self.s)
      self.assertEquals(1, len(children))

      is_established = (state == self.NOT_YET_ACCEPTED)

      # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
      # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
      # Before 4.4, we can see those sockets in dumps, but we can't fetch
      # or close them.
      can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

      for child in children:
        if can_close_children:
          self.sock_diag.GetSockDiag(child)  # No errors? Good, child found.
        else:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseParent(expect_reset):
        msg = "Closing parent IPv%d %s socket %s child" % (
            version, statename, "before" if parent_first else "after")
        self.CheckRstOnClose(self.s, None, expect_reset, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)

      def CheckChildrenClosed():
        for child in children:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseChildren():
        for child in children:
          msg = "Closing child IPv%d %s socket %s parent" % (
              version, statename, "after" if parent_first else "before")
          self.sock_diag.GetSockDiag(child)
          self.CheckRstOnClose(None, child, is_established, msg)
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
        CheckChildrenClosed()

      if parent_first:
        # Closing the parent will close child sockets, which will send a RST,
        # iff they are already established.
        CloseParent(is_established)
        if is_established:
          CheckChildrenClosed()
        elif can_close_children:
          CloseChildren()
          CheckChildrenClosed()
        self.s.close()
      else:
        if can_close_children:
          CloseChildren()
        CloseParent(False)
        self.s.close()

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testChildSockets(self):
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock while another thread is blocked in call(sock).

    Checks that call(sock) fails with an IOError whose errno is
    expected_errno, and that sock is closed afterwards.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    time.sleep(0.1)  # Give the thread time to block inside call().
    self.sock_diag.CloseSocketFromFd(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    self.assertIsNotNone(thread.exception)
    self.assertTrue(isinstance(thread.exception, IOError),
                    "Expected IOError, got %s" % thread.exception)
    self.assertEqual(expected_errno, thread.exception.errno)
    self.assertSocketClosed(sock)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testAcceptInterrupted(self):
    """Tests that accept() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
      self.assertRaisesErrno(EINVAL, self.s.accept)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testConnectInterrupted(self):
    """Tests that connect() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
      self.SelectInterface(s, self.netid, "mark")
      if version == 5:
        remoteaddr = "::ffff:" + self.GetRemoteAddress(4)
        version = 4
      else:
        remoteaddr = self.GetRemoteAddress(version)
      s.bind(("", 0))
      _, sport = s.getsockname()[:2]
      self.CloseDuringBlockingCall(
          s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
                              remoteaddr, sport=sport, seq=None)
      self.ExpectPacketOn(self.netid, desc, syn)
      msg = "SOCK_DESTROY of socket in connect, expected no RST"
      self.ExpectNoPacketsOn(self.netid, msg)
      # Release the fd now rather than waiting for garbage collection.
      s.close()

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
         android-3.4:
           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    self.IncomingConnection(5, sock_diag.TCP_SYN_RECV, self.netid)
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    states = 1 << sock_diag.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
    children = self.sock_diag.Dump(req, NO_BYTECODE)

    self.assertTrue(children)
    for child, _ in children:
      self.assertEqual(sock_diag.TCP_SYN_RECV, child.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
                       child.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
                       child.id.src)


if __name__ == "__main__":
  unittest.main()