sock_diag_test.py revision 2c96358e57d689833af77ec0c13831a9ae1b27b8
1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17from errno import *
18import os
19import random
20from socket import *
21import time
22import unittest
23
24import csocket
25import cstruct
26import multinetwork_base
27import net_test
28import packets
29import sock_diag
30import threading
31
32
# Number of socketpairs created by _CreateLotsOfSockets.
NUM_SOCKETS = 100
# Empty inet_diag bytecode: dumps are not filtered.
NO_BYTECODE = ""

# Kernels older than this do not support SOCK_DESTROY; tests that need it are
# skipped on them.
# TODO: Backport SOCK_DESTROY and delete this.
SOCK_DESTROY_MIN_KERNEL = (4, 4)
HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= SOCK_DESTROY_MIN_KERNEL
38
39
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Shared helpers for sock_diag tests: socket creation and state checks."""

  @staticmethod
  def _CreateLotsOfSockets():
    """Creates NUM_SOCKETS connected loopback TCP socketpairs.

    Each pair is randomly chosen to be IPv4 (127.0.0.1) or IPv6 (::1).

    Returns:
      A dict mapping (addr, sport, dport) to the socketpair, where sport is
      the local port of the first socket and dport that of the second.
    """
    loopbacks = [(AF_INET, "127.0.0.1"), (AF_INET6, "::1")]
    pairs = {}
    for _ in xrange(NUM_SOCKETS):
      family, addr = random.choice(loopbacks)
      pair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
      first_port = pair[0].getsockname()[1]
      second_port = pair[1].getsockname()[1]
      pairs[(addr, first_port, second_port)] = pair
    return pairs

  def assertSocketClosed(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock has a peer; getpeername() raising fails the test."""
    sock.getpeername()

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for s in socketpair:
      self.assertSocketClosed(s)
63
64
class SockDiagTest(SockDiagBaseTest):
  """Tests for inet_diag dumps, bytecode filters and SOCK_DESTROY."""

  def setUp(self):
    super(SockDiagTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Socketpairs owned by the current test; tearDown closes all of them.
    self.socketpairs = {}

  def tearDown(self):
    for socketpair in self.socketpairs.values():
      for s in socketpair:
        s.close()
    super(SockDiagTest, self).tearDown()

  def testFixupDiagMsg(self):
    """Checks that FixupDiagMsg zeroes the unused bytes of IPv4 addresses."""
    src = "0a00fa02303030312030312038302031"
    dst = "0808080841414141414141416f0a3230"
    cookie = "4078678100000000"
    sockid = sock_diag.InetDiagSockId((47436, 32069,
                                       src.decode("hex"), dst.decode("hex"), 0,
                                       cookie.decode("hex")))
    msg4 = sock_diag.InetDiagMsg((AF_INET, IPPROTO_TCP, 0,
                                  sock_diag.TCP_SYN_RECV, sockid,
                                  980, 123, 456, 789, 5555))
    # Make a copy, cstructs are mutable.
    msg6 = sock_diag.InetDiagMsg(msg4.Pack())
    msg6.family = AF_INET6

    # IPv6 messages are left unchanged.
    fixed6 = sock_diag.InetDiagMsg(msg6.Pack())
    self.sock_diag.FixupDiagMsg(fixed6)
    self.assertEqual(msg6.Pack(), fixed6.Pack())

    # IPv4 messages keep only the first 4 bytes of each address.
    fixed4 = sock_diag.InetDiagMsg(msg4.Pack())
    self.sock_diag.FixupDiagMsg(fixed4)
    msg4.id.src = src.decode("hex")[:4] + 12 * "\x00"
    msg4.id.dst = dst.decode("hex")[:4] + 12 * "\x00"
    self.assertEqual(msg4.Pack(), fixed4.Pack())

  def assertSockDiagMatchesSocket(self, s, diag_msg):
    """Asserts that diag_msg describes socket s.

    Note: calls FixupDiagMsg, which modifies diag_msg in place.
    """
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    self.sock_diag.FixupDiagMsg(diag_msg)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      # No destination address in the dump: the socket must be unconnected.
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

  def testFindsAllMySockets(self):
    """Checks that dumps and cookie-based lookups find all our sockets."""
    self.socketpairs = self._CreateLotsOfSockets()
    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. socketpairs is keyed by the
    # ports of the first socket, so the second socket of each pair matches
    # with sport and dport swapped.
    cookies = {}
    for diag_msg, _ in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump. Read the cookie
        # and state before the check below fixes up the message in place.
        found = self.sock_diag.FindSockDiagFromFd(sock)
        cookie = found.id.cookie
        state = found.state
        self.assertSockDiagMatchesSocket(sock, found)

        # Check that we can find a diag_msg once we know the cookie. Filter on
        # this socket's own state.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        req.states = 1 << state
        diag_msg = self.sock_diag.GetSockDiag(req)
        self.assertSockDiagMatchesSocket(sock, diag_msg)

  def testBytecodeCompilation(self):
    """Checks PackBytecode output and that compiled filters select sockets."""
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    bytecode = self.sock_diag.PackBytecode(instructions)
    expected = (
        "0208500000000000"
        "050848000000ffff"
        "071c20000a800000ffffffff00000000000000000000000000000001"
        "01041c00"
        "0718200002200000ffffffff7f000001"
        "0508100000006566"
        "00040400"
    )
    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets()
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.sock_diag.PackBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, ""))

    # Register the pairs so tearDown closes them even if an assertion fails.
    pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    self.socketpairs["pair4"] = pair4
    pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
    self.socketpairs["pair6"] = pair6

    bytecode4 = self.sock_diag.PackBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.sock_diag.PackBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4)
    self.assertTrue(v4sockets)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets))

    v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6)
    self.assertTrue(v6sockets)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    self.socketpairs["pair5"] = pair5
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                                 bytecode4)]
    v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                                 bytecode6)]
    self.assertTrue(all(d in v4sockets for d in diag_msgs))
    self.assertTrue(all(d in v6sockets for d in diag_msgs))

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testClosesSockets(self):
    """Checks that SOCK_DESTROY closes both ends of loopback connections."""
    self.socketpairs = self._CreateLotsOfSockets()
    for socketpair in self.socketpairs.values():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed.
        real_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(real_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testNonTcpSockets(self):
    """Checks that SOCK_DESTROY on a UDP socket fails with EOPNOTSUPP."""
    s = socket(AF_INET6, SOCK_DGRAM, 0)
    try:
      s.connect(("::1", 53))
      # Looking the socket up first checks that a diag_msg lookup succeeds
      # before the close attempt.
      self.sock_diag.FindSockDiagFromFd(s)
      self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
    finally:
      s.close()

  def testNonSockDiagCommand(self):
    """Checks that unknown sock_diag message types are rejected with EINVAL."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.
301
302
class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records what it raises.

  Lets a test start a blocking socket call in the background, act on the
  socket from the main thread, and then inspect self.exception.
  """

  def __init__(self, sock, operation):
    """Constructor.

    Args:
      sock: the socket passed to operation.
      operation: a callable taking sock as its only argument.
    """
    super(SocketExceptionThread, self).__init__()
    self.daemon = True
    self.sock = sock
    self.operation = operation
    # The exception raised by operation, or None if it did not raise.
    self.exception = None

  def run(self):
    try:
      self.operation(self.sock)
    except Exception as e:  # Deliberately broad: the caller inspects it.
      self.exception = e
317
318
319# TODO: Take a tun fd as input, make this a utility class, and reuse at least
320# in forwarding_test.
class TcpTest(SockDiagBaseTest):
  """Tests SOCK_DESTROY on TCP sockets talking to a simulated remote host.

  Remote-side packets are delivered with ReceivePacketOn and the kernel's
  replies are checked with ExpectPacketOn / ExpectNoPacketsOn on self.netid.
  """

  # Pseudo end_state for IncomingConnection: the handshake completes but the
  # connection is not accept()ed.
  NOT_YET_ACCEPTED = -1

  def setUp(self):
    super(TcpTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # All packets in a test go over one randomly chosen network.
    self.netid = random.choice(self.tuns.keys())

  def OpenListenSocket(self, version):
    """Opens a TCP socket listening on a random port of the wildcard address.

    Args:
      version: 4 or 6, the IP version of the socket.

    Returns:
      The listening socket. Its port is stored in self.port.
    """
    self.port = packets.RandomPort()
    family = {4: AF_INET, 6: AF_INET6}[version]
    address = {4: "0.0.0.0", 6: "::"}[version]
    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    s.bind((address, self.port))
    # We haven't configured inbound iptables marking, so bind explicitly.
    self.SelectInterface(s, self.netid, "mark")
    s.listen(100)
    return s

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Calls the superclass method and records its result in last_packet."""
    pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
                                                         reply, msg)
    self.last_packet = pkt
    return pkt

  def ReceivePacketOn(self, netid, packet):
    """Calls the superclass method and records packet in last_packet."""
    super(TcpTest, self).ReceivePacketOn(netid, packet)
    self.last_packet = packet

  def RstPacket(self):
    """Returns (desc, packet) for the RST expected after last_packet."""
    return packets.RST(self.version, self.myaddr, self.remoteaddr,
                       self.last_packet)

  def IncomingConnection(self, version, end_state, netid):
    """Opens a listening socket and drives an incoming connection to end_state.

    Args:
      version: 4 or 6, or 5 for IPv4 traffic to an IPv6 socket using
        IPv4-mapped addresses.
      end_state: TCP_LISTEN, TCP_SYN_RECV, NOT_YET_ACCEPTED, TCP_ESTABLISHED
        or TCP_CLOSE_WAIT.
      netid: the network on which packets are received.

    Sets self.s to the listening socket and, once the connection has been
    accepted, self.accepted to the accepted socket.

    Raises:
      ValueError: if end_state is not supported. This is only detected after
        the full packet exchange has been performed.
    """
    if version == 5:
      mapped = True
      socket_version = 6
      version = 4
    else:
      socket_version = version
      mapped = False

    self.version = version
    self.s = self.OpenListenSocket(socket_version)
    self.end_state = end_state

    def MaybeMappedAddress(addr):
      return "::ffff:%s" % addr if mapped else addr

    remoteaddr = self.remoteaddr = MaybeMappedAddress(
        self.GetRemoteAddress(version))
    myaddr = self.myaddr = MaybeMappedAddress(
        self.MyAddress(version, netid))

    if end_state == sock_diag.TCP_LISTEN:
      return

    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
    if end_state == sock_diag.TCP_SYN_RECV:
      return

    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    self.ReceivePacketOn(netid, establishing_ack)

    if end_state == self.NOT_YET_ACCEPTED:
      return

    self.accepted, _ = self.s.accept()
    net_test.DisableLinger(self.accepted)

    if end_state == sock_diag.TCP_ESTABLISHED:
      return

    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                             payload=net_test.UDP_PAYLOAD)
    self.accepted.send(net_test.UDP_PAYLOAD)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)

    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
    fin = packets._GetIpLayer(version)(str(fin))
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)

    # TODO: Why can't we use this?
    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
    self.ReceivePacketOn(netid, fin)
    time.sleep(0.1)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
    if end_state == sock_diag.TCP_CLOSE_WAIT:
      return

    raise ValueError("Invalid TCP state %d specified" % end_state)

  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    """Closes a socket with SOCK_DESTROY and checks whether a RST is sent.

    Exactly one of sock and req must be given.

    Args:
      sock: the socket to destroy by fd, or None.
      req: the diag request identifying the socket to destroy, or None.
      expect_reset: whether a RST is expected on self.netid.
      msg: description used in failure messages.
      do_close: whether to close() sock afterwards.
    """
    if sock is not None:
      self.assertIsNone(req, "Must specify sock or req, not both")
      self.sock_diag.CloseSocketFromFd(sock)
      self.assertRaisesErrno(EINVAL, sock.accept)
    else:
      self.assertIsNotNone(req, "Must specify either sock or req")
      self.sock_diag.CloseSocket(req)

    if expect_reset:
      desc, rst = self.RstPacket()
      msg = "%s: expecting %s: " % (msg, desc)
      self.ExpectPacketOn(self.netid, msg, rst)
    else:
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)

    if sock is not None and do_close:
      sock.close()

  def CheckTcpReset(self, state, statename):
    """Checks RST behaviour when destroying the listener and accepted socket."""
    for version in [4, 6]:
      msg = "Closing incoming IPv%d %s socket" % (version, statename)
      self.IncomingConnection(version, state, self.netid)
      self.CheckRstOnClose(self.s, None, False, msg)
      if state != sock_diag.TCP_LISTEN:
        msg = "Closing accepted IPv%d %s socket" % (version, statename)
        self.CheckRstOnClose(self.accepted, None, True, msg)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testTcpResets(self):
    """Checks that closing sockets in appropriate states sends a RST."""
    self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
    self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")

  def FindChildSockets(self, s):
    """Finds the child sockets of listening socket s.

    Dumps SYN_RECV and ESTABLISHED sockets matching the diag id of s (with a
    zero cookie) and returns a diag request for each one.
    """
    listener = self.sock_diag.FindSockDiagFromFd(s)
    req = self.sock_diag.DiagReqFromDiagMsg(listener, IPPROTO_TCP)
    req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
    req.id.cookie = "\x00" * 8
    children = self.sock_diag.Dump(req, NO_BYTECODE)
    return [self.sock_diag.DiagReqFromDiagMsg(child, IPPROTO_TCP)
            for child, _ in children]

  def CheckChildSocket(self, state, statename, parent_first):
    """Checks destroying a listener and its child in either order."""
    for version in [4, 6]:
      self.IncomingConnection(version, state, self.netid)

      d = self.sock_diag.FindSockDiagFromFd(self.s)
      parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
      children = self.FindChildSockets(self.s)
      self.assertEqual(1, len(children))

      is_established = (state == self.NOT_YET_ACCEPTED)

      # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
      # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
      # Before 4.4, we can see those sockets in dumps, but we can't fetch
      # or close them.
      can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

      for child in children:
        if can_close_children:
          self.sock_diag.GetSockDiag(child)  # No errors? Good, child found.
        else:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseParent(expect_reset):
        msg = "Closing parent IPv%d %s socket %s child" % (
            version, statename, "before" if parent_first else "after")
        self.CheckRstOnClose(self.s, None, expect_reset, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)

      def CheckChildrenClosed():
        for child in children:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseChildren():
        for child in children:
          msg = "Closing child IPv%d %s socket %s parent" % (
              version, statename, "after" if parent_first else "before")
          self.sock_diag.GetSockDiag(child)
          self.CheckRstOnClose(None, child, is_established, msg)
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
        CheckChildrenClosed()

      if parent_first:
        # Closing the parent will close child sockets, which will send a RST,
        # iff they are already established.
        CloseParent(is_established)
        if is_established:
          CheckChildrenClosed()
        elif can_close_children:
          CloseChildren()
          CheckChildrenClosed()
        self.s.close()
      else:
        if can_close_children:
          CloseChildren()
        CloseParent(False)
        self.s.close()

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testChildSockets(self):
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock while call(sock) blocks in another thread.

    Checks that the call fails with an IOError carrying expected_errno and
    that sock ends up closed.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call.
    time.sleep(0.1)
    self.sock_diag.CloseSocketFromFd(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    self.assertIsNotNone(thread.exception)
    self.assertIsInstance(thread.exception, IOError,
                          "Expected IOError, got %s" % thread.exception)
    self.assertEqual(expected_errno, thread.exception.errno)
    self.assertSocketClosed(sock)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testAcceptInterrupted(self):
    """Tests that accept() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
      self.assertRaisesErrno(EINVAL, self.s.accept)
      # Release the fd; each iteration opens a new listening socket.
      self.s.close()

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
      # Release the fds; each iteration opens new sockets.
      self.accepted.close()
      self.s.close()

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testConnectInterrupted(self):
    """Tests that connect() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
      self.SelectInterface(s, self.netid, "mark")
      if version == 5:
        # IPv4 traffic from an IPv6 socket via an IPv4-mapped address.
        remoteaddr = "::ffff:" + self.GetRemoteAddress(4)
        version = 4
      else:
        remoteaddr = self.GetRemoteAddress(version)
      s.bind(("", 0))
      _, sport = s.getsockname()[:2]
      self.CloseDuringBlockingCall(
          s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
                              remoteaddr, sport=sport, seq=None)
      self.ExpectPacketOn(self.netid, desc, syn)
      msg = "SOCK_DESTROY of socket in connect, expected no RST"
      self.ExpectNoPacketsOn(self.netid, msg)
      s.close()

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
         android-3.4:
           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    self.IncomingConnection(5, sock_diag.TCP_SYN_RECV, self.netid)
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    states = 1 << sock_diag.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
    children = self.sock_diag.Dump(req, NO_BYTECODE)

    self.assertTrue(children)
    for child, unused_args in children:
      self.assertEqual(sock_diag.TCP_SYN_RECV, child.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
                       child.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
                       child.id.src)
605
606
if __name__ == "__main__":
  # Run every test case defined in this module when executed as a script.
  unittest.main(module="__main__")
609