sock_diag_test.py revision eda52a427837c5840e991f41e0fcfb9b5dfc38a9
#!/usr/bin/python
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for sock_diag socket dumps/lookups and for SOCK_DESTROY."""

from errno import *
import os
import random
from socket import *
import threading
import time
import unittest

import csocket
import cstruct
import multinetwork_base
import net_test
import packets
import sock_diag


NUM_SOCKETS = 100

# Bitmask of every TCP state except TIME_WAIT, for use as a dump state filter.
ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << sock_diag.TCP_TIME_WAIT)

# TODO: Backport SOCK_DESTROY and delete this.
HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)


class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Helpers shared by all sock_diag test classes."""

  @staticmethod
  def _CreateLotsOfSockets():
    """Creates NUM_SOCKETS connected TCP socket pairs over loopback.

    Each pair is randomly IPv4 (127.0.0.1) or IPv6 (::1).

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      and dport are the local ports of the first and second socket.
    """
    socketpairs = {}
    for _ in range(NUM_SOCKETS):
      family, addr = random.choice([(AF_INET, "127.0.0.1"), (AF_INET6, "::1")])
      socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  # These assertions live on the base class because both SockDiagTest and
  # TcpTest (e.g. TcpTest.CloseDuringBlockingCall) use them.
  def assertSocketClosed(self, sock):
    """Asserts that sock is no longer connected (getpeername -> ENOTCONN)."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock is still connected to a peer."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)


class SockDiagTest(SockDiagBaseTest):
  """Tests sock_diag dumps, lookups and SOCK_DESTROY on loopback sockets."""

  def setUp(self):
    super(SockDiagTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    self.socketpairs = {}

  def tearDown(self):
    for socketpair in self.socketpairs.values():
      for s in socketpair:
        s.close()
    super(SockDiagTest, self).tearDown()

  def testFixupDiagMsg(self):
    src = "0a00fa02303030312030312038302031"
    dst = "0808080841414141414141416f0a3230"
    cookie = "4078678100000000"
    sockid = sock_diag.InetDiagSockId((47436, 32069,
                                       src.decode("hex"), dst.decode("hex"), 0,
                                       cookie.decode("hex")))
    msg4 = sock_diag.InetDiagMsg((AF_INET, IPPROTO_TCP, 0,
                                  sock_diag.TCP_SYN_RECV, sockid,
                                  980, 123, 456, 789, 5555))
    # Make a copy, cstructs are mutable.
    msg6 = sock_diag.InetDiagMsg(msg4.Pack())
    msg6.family = AF_INET6

    # An AF_INET6 message must come through FixupDiagMsg unchanged.
    fixed6 = sock_diag.InetDiagMsg(msg6.Pack())
    self.sock_diag.FixupDiagMsg(fixed6)
    self.assertEqual(msg6.Pack(), fixed6.Pack())

    # For AF_INET, only the first 4 bytes of src/dst are kept and the
    # remaining 12 bytes are expected to be zeroed.
    fixed4 = sock_diag.InetDiagMsg(msg4.Pack())
    self.sock_diag.FixupDiagMsg(fixed4)
    msg4.id.src = src.decode("hex")[:4] + 12 * "\x00"
    msg4.id.dst = dst.decode("hex")[:4] + 12 * "\x00"
    self.assertEqual(msg4.Pack(), fixed4.Pack())

  def assertSockDiagMatchesSocket(self, s, diag_msg):
    """Asserts that diag_msg describes socket s.

    Note: calls FixupDiagMsg on diag_msg, i.e. modifies it in place.
    """
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    self.sock_diag.FixupDiagMsg(diag_msg)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      # A wildcard destination means the socket must not be connected.
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

  def testFindsAllMySockets(self):
    self.socketpairs = self._CreateLotsOfSockets()
    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                states=ALL_NON_TIME_WAIT)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. Each pair shows up twice in
    # the dump, once per end, with sport and dport swapped.
    cookies = {}
    for diag_msg, _ in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        dumped = self.sock_diag.FindSockDiagFromFd(sock)
        self.assertSockDiagMatchesSocket(sock, dumped)

        # Check that we can find a diag_msg once we know the cookie. The state
        # filter must come from this socket's own diag_msg, not from whatever
        # message happened to be processed last.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = dumped.id.cookie
        req.states = 1 << dumped.state
        fetched = self.sock_diag.GetSockDiag(req)
        self.assertSockDiagMatchesSocket(sock, fetched)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testClosesSockets(self):
    self.socketpairs = self._CreateLotsOfSockets()
    for socketpair in self.socketpairs.values():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed.
        real_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(real_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testNonTcpSockets(self):
    s = socket(AF_INET6, SOCK_DGRAM, 0)
    try:
      s.connect(("::1", 53))
      self.sock_diag.FindSockDiagFromFd(s)  # No errors? Good, socket found.
      self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
    finally:
      s.close()

  def testNonSockDiagCommand(self):
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg)

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.


class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records what it raises.

  After the thread finishes, self.exception holds the exception raised by
  operation, or None if it returned normally.
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def run(self):
    try:
      self.operation(self.sock)
    except Exception as e:  # Deliberately broad: the caller inspects it.
      self.exception = e


# TODO: Take a tun fd as input, make this a utility class, and reuse at least
# in forwarding_test.
class TcpTest(SockDiagBaseTest):
  """Tests SOCK_DESTROY on TCP sockets driven by injected packets."""

  # Pseudo end state for IncomingConnection: handshake done, not accept()ed.
  NOT_YET_ACCEPTED = -1

  def setUp(self):
    super(TcpTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    self.netid = random.choice(list(self.tuns.keys()))

  def OpenListenSocket(self, version):
    """Returns a TCP socket listening on a random port (stored in self.port).

    Args:
      version: 4 or 6, selecting AF_INET or AF_INET6.
    """
    self.port = packets.RandomPort()
    family = {4: AF_INET, 6: AF_INET6}[version]
    address = {4: "0.0.0.0", 6: "::"}[version]
    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    s.bind((address, self.port))
    # We haven't configured inbound iptables marking, so bind explicitly.
    self.SelectInterface(s, self.netid, "mark")
    s.listen(100)
    return s

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    # Remember the reply so RstPacket() can build a matching RST.
    pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
                                                         reply, msg)
    self.last_packet = pkt
    return pkt

  def ReceivePacketOn(self, netid, packet):
    # Remember the injected packet so RstPacket() can build a matching RST.
    super(TcpTest, self).ReceivePacketOn(netid, packet)
    self.last_packet = packet

  def RstPacket(self):
    """Returns the (desc, packet) RST we expect in reply to last_packet."""
    return packets.RST(self.version, self.myaddr, self.remoteaddr,
                       self.last_packet)

  def IncomingConnection(self, version, end_state, netid):
    """Drives an incoming connection to a new listener into end_state.

    Opens a listening socket in self.s and then injects packets on netid,
    checking the replies, until end_state is reached.

    Args:
      version: 4 or 6, or 5 for IPv4 traffic to an IPv6 socket using
        IPv4-mapped addresses.
      end_state: sock_diag.TCP_LISTEN, sock_diag.TCP_SYN_RECV,
        self.NOT_YET_ACCEPTED, sock_diag.TCP_ESTABLISHED or
        sock_diag.TCP_CLOSE_WAIT. For the last two, the accepted socket is
        stored in self.accepted.
      netid: the network on which to inject packets.

    Raises:
      ValueError: if end_state is not one of the above. This is only raised
        after the connection has been driven all the way to CLOSE_WAIT.
    """
    if version == 5:
      mapped = True
      socket_version = 6
      version = 4
    else:
      socket_version = version
      mapped = False

    self.version = version
    self.s = self.OpenListenSocket(socket_version)
    self.end_state = end_state

    def MaybeMappedAddress(addr):
      return "::ffff:%s" % addr if mapped else addr

    remoteaddr = self.remoteaddr = MaybeMappedAddress(
        self.GetRemoteAddress(version))
    myaddr = self.myaddr = MaybeMappedAddress(
        self.MyAddress(version, netid))

    if end_state == sock_diag.TCP_LISTEN:
      return

    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
    if end_state == sock_diag.TCP_SYN_RECV:
      return

    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    self.ReceivePacketOn(netid, establishing_ack)

    if end_state == self.NOT_YET_ACCEPTED:
      return

    self.accepted, _ = self.s.accept()
    if end_state == sock_diag.TCP_ESTABLISHED:
      return

    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                             payload=net_test.UDP_PAYLOAD)
    self.accepted.send(net_test.UDP_PAYLOAD)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)

    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
    fin = packets._GetIpLayer(version)(str(fin))
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)

    # TODO: Why can't we use this?
    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
    self.ReceivePacketOn(netid, fin)
    time.sleep(0.1)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
    if end_state == sock_diag.TCP_CLOSE_WAIT:
      return

    raise ValueError("Invalid TCP state %d specified" % end_state)

  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    """Closes the socket and checks whether a RST is sent or not.

    Exactly one of sock and req must be given.

    Args:
      sock: a socket object to destroy with CloseSocketFromFd, or None.
      req: a diag request to destroy with CloseSocket, or None.
      expect_reset: whether a RST is expected on self.netid.
      msg: description used in failure messages.
      do_close: whether to also close() sock afterwards.
    """
    if sock is not None:
      self.assertIsNone(req, "Must specify sock or req, not both")
      self.sock_diag.CloseSocketFromFd(sock)
      self.assertRaisesErrno(EINVAL, sock.accept)
    else:
      self.assertIsNotNone(req, "Must specify sock or req")
      self.sock_diag.CloseSocket(req)

    if expect_reset:
      desc, rst = self.RstPacket()
      msg = "%s: expecting %s: " % (msg, desc)
      self.ExpectPacketOn(self.netid, msg, rst)
    else:
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)

    if sock is not None and do_close:
      sock.close()

  def CheckTcpReset(self, state, statename):
    """Checks RST behaviour when destroying listener/accepted sockets."""
    for version in [4, 6]:
      msg = "Closing incoming IPv%d %s socket" % (version, statename)
      self.IncomingConnection(version, state, self.netid)
      self.CheckRstOnClose(self.s, None, False, msg)
      if state != sock_diag.TCP_LISTEN:
        msg = "Closing accepted IPv%d %s socket" % (version, statename)
        self.CheckRstOnClose(self.accepted, None, True, msg)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testTcpResets(self):
    """Checks that closing sockets in appropriate states sends a RST."""
    self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
    self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")

  def FindChildSockets(self, s):
    """Finds the SYN_RECV and ESTABLISHED child sockets of listener s.

    Returns:
      A list of diag requests, one per child found in the dump.
    """
    d = self.sock_diag.FindSockDiagFromFd(s)
    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    req.states = ((1 << sock_diag.TCP_SYN_RECV) |
                  (1 << sock_diag.TCP_ESTABLISHED))
    req.id.cookie = "\x00" * 8
    children = self.sock_diag.Dump(req)
    return [self.sock_diag.DiagReqFromDiagMsg(child, IPPROTO_TCP)
            for child, _ in children]

  def CheckChildSocket(self, state, statename, parent_first):
    """Checks destroying a listener and its single not-yet-accepted child.

    Args:
      state: sock_diag.TCP_SYN_RECV or self.NOT_YET_ACCEPTED.
      statename: description used in failure messages.
      parent_first: whether to destroy the parent before the child.
    """
    for version in [4, 6]:
      self.IncomingConnection(version, state, self.netid)

      d = self.sock_diag.FindSockDiagFromFd(self.s)
      parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
      children = self.FindChildSockets(self.s)
      self.assertEqual(1, len(children))

      is_established = (state == self.NOT_YET_ACCEPTED)

      # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
      # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
      # Before 4.4, we can see those sockets in dumps, but we can't fetch
      # or close them.
      can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

      for child in children:
        if can_close_children:
          self.sock_diag.GetSockDiag(child)  # No errors? Good, child found.
        else:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseParent(expect_reset):
        msg = "Closing parent IPv%d %s socket %s child" % (
            version, statename, "before" if parent_first else "after")
        self.CheckRstOnClose(self.s, None, expect_reset, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)

      def CheckChildrenClosed():
        for child in children:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseChildren():
        for child in children:
          msg = "Closing child IPv%d %s socket %s parent" % (
              version, statename, "after" if parent_first else "before")
          self.sock_diag.GetSockDiag(child)
          self.CheckRstOnClose(None, child, is_established, msg)
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
        CheckChildrenClosed()

      if parent_first:
        # Closing the parent will close child sockets, which will send a RST,
        # iff they are already established.
        CloseParent(is_established)
        if is_established:
          CheckChildrenClosed()
        elif can_close_children:
          CloseChildren()
          CheckChildrenClosed()
        self.s.close()
      else:
        if can_close_children:
          CloseChildren()
        CloseParent(False)
        self.s.close()

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testChildSockets(self):
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock while call(sock) runs in another thread.

    Asserts that the call fails with an IOError carrying expected_errno and
    that sock is closed afterwards. Assumes call(sock) is already blocked
    after a 0.1 second sleep.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    time.sleep(0.1)
    self.sock_diag.CloseSocketFromFd(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    self.assertIsNotNone(thread.exception)
    self.assertTrue(isinstance(thread.exception, IOError),
                    "Expected IOError, got %s" % thread.exception)
    self.assertEqual(expected_errno, thread.exception.errno)
    self.assertSocketClosed(sock)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testAcceptInterrupted(self):
    """Tests that accept() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
      self.assertRaisesErrno(EINVAL, self.s.accept)
      self.s.close()

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
      self.accepted.close()
      self.s.close()

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testConnectInterrupted(self):
    """Tests that connect() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      # As in IncomingConnection, version 5 means an IPv6 socket talking to
      # an IPv4-mapped address; the packets on the wire are IPv4.
      mapped = (version == 5)
      packet_version = 4 if mapped else version
      family = AF_INET if version == 4 else AF_INET6
      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
      try:
        self.SelectInterface(s, self.netid, "mark")
        remoteaddr = self.GetRemoteAddress(packet_version)
        connectaddr = "::ffff:%s" % remoteaddr if mapped else remoteaddr
        s.bind(("", 0))
        _, sport = s.getsockname()[:2]
        self.CloseDuringBlockingCall(
            s, lambda sock, addr=connectaddr: sock.connect((addr, 53)),
            ECONNABORTED)
        desc, syn = packets.SYN(53, packet_version,
                                self.MyAddress(packet_version, self.netid),
                                remoteaddr, sport=sport, seq=None)
        self.ExpectPacketOn(self.netid, desc, syn)
        msg = "SOCK_DESTROY of socket in connect, expected no RST"
        self.ExpectNoPacketsOn(self.netid, msg)
      finally:
        s.close()

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
      android-3.4:
        457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    self.IncomingConnection(5, sock_diag.TCP_SYN_RECV, self.netid)
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    states = 1 << sock_diag.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
    children = self.sock_diag.Dump(req)

    self.assertTrue(children)
    for child, unused_args in children:
      self.assertEqual(sock_diag.TCP_SYN_RECV, child.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
                       child.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
                       child.id.src)


if __name__ == "__main__":
  unittest.main()