sock_diag_test.py revision 63195c89b87f54f7c7a3ade1ed9b0c4fcab6557b
1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17from errno import *
18import os
19import random
20from socket import *
21import time
22import unittest
23
24import csocket
25import cstruct
26import multinetwork_base
27import net_test
28import packets
29import sock_diag
30import threading
31
32
# Number of socket pairs that _CreateLotsOfSockets opens.
NUM_SOCKETS = 30
# Empty filter: passed to the dump calls when no bytecode filtering is wanted.
NO_BYTECODE = ""

# True when the running kernel is 4.4 or newer; used to skip the SOCK_DESTROY
# tests on older kernels.
# TODO: Backport SOCK_DESTROY and delete this.
HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)
38
39
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Shared fixture for the sock_diag tests.

  setUp creates a SockDiag handle in self.sock_diag and an empty
  self.socketpairs dict; tearDown closes every socket stored in that dict.
  """

  @staticmethod
  def _CreateLotsOfSockets():
    """Opens NUM_SOCKETS SOCK_STREAM socket pairs on loopback addresses.

    The address of each pair is picked at random from IPv4 loopback, IPv6
    loopback and IPv4-mapped IPv6 loopback.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      and dport are the local ports of the first and second socket of the
      pair.
    """
    choices = [
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1"),
    ]
    # Dict mapping (addr, sport, dport) tuples to socketpairs.
    socketpairs = {}
    for _ in xrange(NUM_SOCKETS):
      family, addr = random.choice(choices)
      socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Raises if sock has no peer, i.e. is not connected."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def setUp(self):
    """Creates the SockDiag handle and an empty socketpair registry."""
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    self.socketpairs = {}

  def tearDown(self):
    """Closes all sockets in self.socketpairs, then runs the base tearDown."""
    try:
      for socketpair in self.socketpairs.values():
        for s in socketpair:
          s.close()
    finally:
      # Always let the base class clean up, even if a close() failed.
      super(SockDiagBaseTest, self).tearDown()
75
76
class SockDiagTest(SockDiagBaseTest):
  """Tests for dumping and looking up sockets through sock_diag."""

  def assertSockDiagMatchesSocket(self, s, diag_msg):
    """Asserts that diag_msg describes the socket s.

    Checks the address family and the local address and port. If diag_msg has
    a non-wildcard destination the peer address and port are checked as well;
    otherwise the socket must be unconnected.

    Args:
      s: The socket to compare against.
      diag_msg: The diag message reported for the socket.
    """
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      # assertRaisesErrno is a method of the test case, not a global function.
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets.

    Relevant kernel commits:
      android-3.10:
        f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
    """
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    for sock in socketpair:
      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
      self.sock_diag.GetSockDiag(diag_req)
      # No errors? Good.

  def testFindsAllMySockets(self):
    """Tests that basic socket dumping works.

    Relevant commits:
      android-3.4:
        ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      android-3.10
        3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
    """
    self.socketpairs = self._CreateLotsOfSockets()
    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets.
    cookies = {}
    for dumped_msg, _ in sockets:
      addr = self.sock_diag.GetSourceAddress(dumped_msg)
      sport = dumped_msg.id.sport
      dport = dumped_msg.id.dport
      # Each pair is stored under one port order only, so also try the ports
      # swapped to recognise the other end of the connection.
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = dumped_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
        self.assertSockDiagMatchesSocket(sock, diag_msg)
        cookie = diag_msg.id.cookie

        # Check that we can find a diag_msg once we know the cookie. The state
        # mask must come from this socket's own diag_msg, not from whatever
        # message an earlier loop happened to leave behind.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        req.states = 1 << diag_msg.state
        diag_msg = self.sock_diag.GetSockDiag(req)
        self.assertSockDiagMatchesSocket(sock, diag_msg)

  def testBytecodeCompilation(self):
    """Checks bytecode packing and that packed filters select sockets.

    First compares the packed form of a fixed program against a known hex
    string and checks that dumping with it returns the same sockets as an
    unfiltered dump. Then, for a few sockets, builds a filter on that
    socket's source and destination ports and checks that exactly that
    socket is returned.
    """
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    bytecode = self.sock_diag.PackBytecode(instructions)
    expected = (
        "0208500000000000"
        "050848000000ffff"
        "071c20000a800000ffffffff00000000000000000000000000000001"
        "01041c00"
        "0718200002200000ffffffff7f000001"
        "0508100000006566"
        "00040400"
    )
    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets()
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.sock_diag.PackBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    self.assertFalse(
        self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE))

    # The pairs are only held in locals so that the sockets stay open while
    # the dumps below run.
    pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")

    bytecode4 = self.sock_diag.PackBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    bytecode6 = self.sock_diag.PackBytecode([
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4)
    self.assertTrue(v4sockets)
    self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets))

    v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6)
    self.assertTrue(v6sockets)
    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets))

    # Except for mapped addresses, which match both IPv4 and IPv6.
    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                      "::ffff:127.0.0.1")
    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                                 bytecode4)]
    v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                                 bytecode6)]
    self.assertTrue(all(d in v4sockets for d in diag_msgs))
    self.assertTrue(all(d in v6sockets for d in diag_msgs))

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    """Checks that a dump with an unknown command code fails with EINVAL."""

    def DiagDump(code):
      """Dumps all AF_INET6 TCP sockets using the given command code."""
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
263
264
@unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
class SockDestroyTest(SockDiagBaseTest):
  """Tests closing sockets through sock_diag."""

  def testClosesSockets(self):
    """Closes one end of every socket pair and checks both ends are closed.

    Each socket is closed either directly from its fd or, at random, through
    an explicit request; in the latter case a request with a wrong cookie is
    tried first and must fail with ENOENT without closing the socket.
    """
    self.socketpairs = self._CreateLotsOfSockets()
    for socketpair in self.socketpairs.values():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed.
        real_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(real_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  def testNonTcpSockets(self):
    """Checks that closing a UDP socket via sock_diag gives EOPNOTSUPP."""
    s = socket(AF_INET6, SOCK_DGRAM, 0)
    try:
      s.connect(("::1", 53))
      # The returned message is not needed; the lookup itself is kept from
      # the original test.
      self.sock_diag.FindSockDiagFromFd(s)
      self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
    finally:
      # This socket is not in self.socketpairs, so tearDown won't close it.
      s.close()

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.
303
304
class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs an operation on a socket and records failures.

  After the thread has finished, self.exception holds the exception raised
  by the operation, or None if it completed without raising.
  """

  def __init__(self, sock, operation):
    """Prepares the thread.

    Args:
      sock: The object passed to operation when the thread runs.
      operation: A callable taking sock as its only argument.
    """
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    # Don't keep the process alive if the operation blocks forever.
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def run(self):
    """Runs the operation, storing any exception instead of propagating it."""
    try:
      self.operation(self.sock)
    except Exception as e:
      self.exception = e
319
320
321# TODO: Take a tun fd as input, make this a utility class, and reuse at least
322# in forwarding_test.
class TcpTest(SockDiagBaseTest):
  """Tests sock_diag against TCP sockets driven by injected packets.

  "Version" 5 is used throughout for IPv4 traffic on an AF_INET6 socket; the
  packets exchanged in that case are IPv4.
  """

  # Pseudo-state for IncomingConnection: the handshake has completed but
  # accept() has not been called yet.
  NOT_YET_ACCEPTED = -1

  def setUp(self):
    """Picks a random network to test on.

    self.sock_diag is already created by SockDiagBaseTest.setUp.
    """
    super(TcpTest, self).setUp()
    self.netid = random.choice(list(self.tuns.keys()))

  def OpenListenSocket(self, version):
    """Opens a listening TCP socket on a random port, stored in self.port.

    Args:
      version: 4 for AF_INET, 5 or 6 for AF_INET6.

    Returns:
      The listening socket, bound to the wildcard address.
    """
    self.port = packets.RandomPort()
    family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
    address = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    s.bind((address, self.port))
    # We haven't configured inbound iptables marking, so bind explicitly.
    self.SelectInterface(s, self.netid, "mark")
    s.listen(100)
    return s

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Like the base method, but remembers the result in self.last_packet."""
    pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
                                                         reply, msg)
    self.last_packet = pkt
    return pkt

  def ReceivePacketOn(self, netid, packet):
    """Like the base method, but remembers packet in self.last_packet."""
    super(TcpTest, self).ReceivePacketOn(netid, packet)
    self.last_packet = packet

  def RstPacket(self):
    """Builds the RST expected from us, based on self.last_packet."""
    return packets.RST(self.version, self.myaddr, self.remoteaddr,
                       self.last_packet)

  def IncomingConnection(self, version, end_state, netid):
    """Opens a listening socket and drives a connection to it to end_state.

    The listening socket is stored in self.s and, from TCP_ESTABLISHED on,
    the accepted socket in self.accepted. Also sets self.end_state,
    self.remoteaddr, self.myaddr and self.version.

    Args:
      version: 4, 5 or 6; 5 uses an AF_INET6 socket with IPv4 packets.
      end_state: One of sock_diag.TCP_LISTEN, TCP_SYN_RECV, TCP_ESTABLISHED,
        TCP_CLOSE_WAIT, or self.NOT_YET_ACCEPTED.
      netid: The network on which the packets are exchanged.

    Raises:
      ValueError: end_state is none of the above. This is only detected after
        the whole packet exchange has already been performed.
    """
    self.s = self.OpenListenSocket(version)
    self.end_state = end_state

    remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.myaddr = self.MyAddress(version, netid)

    # From here on only the version of the packets on the wire matters.
    if version == 5:
      version = 4
    self.version = version

    if end_state == sock_diag.TCP_LISTEN:
      return

    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
    if end_state == sock_diag.TCP_SYN_RECV:
      return

    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    self.ReceivePacketOn(netid, establishing_ack)

    if end_state == self.NOT_YET_ACCEPTED:
      return

    self.accepted, _ = self.s.accept()
    net_test.DisableLinger(self.accepted)

    if end_state == sock_diag.TCP_ESTABLISHED:
      return

    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                             payload=net_test.UDP_PAYLOAD)
    self.accepted.send(net_test.UDP_PAYLOAD)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)

    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
    fin = packets._GetIpLayer(version)(str(fin))
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)

    # TODO: Why can't we use this?
    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
    self.ReceivePacketOn(netid, fin)
    time.sleep(0.1)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
    if end_state == sock_diag.TCP_CLOSE_WAIT:
      return

    raise ValueError("Invalid TCP state %d specified" % end_state)

  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    """Closes the socket and checks whether a RST is sent or not.

    Exactly one of sock and req must be given; the other must be None.

    Args:
      sock: The socket to destroy via its fd, or None.
      req: The diag request identifying the socket to destroy, or None.
      expect_reset: Whether a RST packet is expected on self.netid.
      msg: Prefix for the assertion messages.
      do_close: Whether to also close() sock afterwards, if sock was given.
    """
    if sock is not None:
      self.assertIsNone(req, "Must specify sock or req, not both")
      self.sock_diag.CloseSocketFromFd(sock)
      self.assertRaisesErrno(EINVAL, sock.accept)
    else:
      # sock is known to be None here, so the meaningful check is that a
      # request was actually supplied.
      self.assertIsNotNone(req, "Must specify either sock or req")
      self.sock_diag.CloseSocket(req)

    if expect_reset:
      desc, rst = self.RstPacket()
      msg = "%s: expecting %s: " % (msg, desc)
      self.ExpectPacketOn(self.netid, msg, rst)
    else:
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)

    if sock is not None and do_close:
      sock.close()

  def CheckTcpReset(self, state, statename):
    """Checks RST behaviour when destroying sockets in the given state.

    For every version, the listening socket must close without a RST. Unless
    the state is TCP_LISTEN, the accepted socket must send a RST.

    Args:
      state: The state to bring the connection to, as for IncomingConnection.
      statename: Human-readable name of the state, for messages.
    """
    for version in [4, 5, 6]:
      msg = "Closing incoming IPv%d %s socket" % (version, statename)
      self.IncomingConnection(version, state, self.netid)
      self.CheckRstOnClose(self.s, None, False, msg)
      if state != sock_diag.TCP_LISTEN:
        msg = "Closing accepted IPv%d %s socket" % (version, statename)
        self.CheckRstOnClose(self.accepted, None, True, msg)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testTcpResets(self):
    """Checks that closing sockets in appropriate states sends a RST."""
    self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
    self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")

  def FindChildSockets(self, s):
    """Finds the SYN_RECV child sockets of a given listening socket.

    Sockets in TCP_ESTABLISHED are included in the dump as well.

    Args:
      s: The listening socket.

    Returns:
      A list of diag requests, one per child socket found.
    """
    d = self.sock_diag.FindSockDiagFromFd(s)
    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
    # Reset the cookie so the dump is not tied to the listening socket's own
    # cookie (presumably all-zeroes means "any" — TODO confirm).
    req.id.cookie = "\x00" * 8
    children = self.sock_diag.Dump(req, NO_BYTECODE)
    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
            for d, _ in children]

  def CheckChildSocket(self, state, statename, parent_first):
    """Checks destroying a listening socket and its not-yet-accepted child.

    Args:
      state: sock_diag.TCP_SYN_RECV or self.NOT_YET_ACCEPTED.
      statename: Human-readable name of the state, for messages.
      parent_first: Whether to close the listening socket before its child.
    """
    for version in [4, 5, 6]:
      self.IncomingConnection(version, state, self.netid)

      d = self.sock_diag.FindSockDiagFromFd(self.s)
      parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
      children = self.FindChildSockets(self.s)
      self.assertEqual(1, len(children))

      is_established = (state == self.NOT_YET_ACCEPTED)

      # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
      # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
      # Before 4.4, we can see those sockets in dumps, but we can't fetch
      # or close them.
      can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

      for child in children:
        if can_close_children:
          self.sock_diag.GetSockDiag(child)  # No errors? Good, child found.
        else:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseParent(expect_reset):
        """Destroys the listening socket and checks that it is gone."""
        msg = "Closing parent IPv%d %s socket %s child" % (
            version, statename, "before" if parent_first else "after")
        self.CheckRstOnClose(self.s, None, expect_reset, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)

      def CheckChildrenClosed():
        """Checks that none of the children can be found any more."""
        for child in children:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseChildren():
        """Destroys every child, then checks that all of them are gone."""
        for child in children:
          msg = "Closing child IPv%d %s socket %s parent" % (
              version, statename, "after" if parent_first else "before")
          self.sock_diag.GetSockDiag(child)
          self.CheckRstOnClose(None, child, is_established, msg)
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
        CheckChildrenClosed()

      if parent_first:
        # Closing the parent will close child sockets, which will send a RST,
        # iff they are already established.
        CloseParent(is_established)
        if is_established:
          CheckChildrenClosed()
        elif can_close_children:
          CloseChildren()  # Also verifies that the children are gone.
      else:
        if can_close_children:
          CloseChildren()
        CloseParent(False)
      self.s.close()

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testChildSockets(self):
    """Checks destroying parent and child sockets in either order."""
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock while another thread is blocked in a call on it.

    Args:
      sock: The socket to operate on and then destroy.
      call: A callable taking the socket; expected to block until the socket
        is destroyed and then fail.
      expected_errno: The errno the blocked call must fail with.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    time.sleep(0.1)  # Give the thread time to enter the blocking call.
    self.sock_diag.CloseSocketFromFd(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    self.assertIsNotNone(thread.exception)
    self.assertIsInstance(thread.exception, IOError,
                          "Expected IOError, got %s" % thread.exception)
    self.assertEqual(expected_errno, thread.exception.errno)
    self.assertSocketClosed(sock)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testAcceptInterrupted(self):
    """Tests that accept() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
      self.assertRaisesErrno(EINVAL, self.s.accept)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testConnectInterrupted(self):
    """Tests that connect() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
      self.SelectInterface(s, self.netid, "mark")
      if version == 5:
        # IPv4 destination reached through an AF_INET6 socket.
        remoteaddr = "::ffff:" + self.GetRemoteAddress(4)
        version = 4
      else:
        remoteaddr = self.GetRemoteAddress(version)
      s.bind(("", 0))
      _, sport = s.getsockname()[:2]
      self.CloseDuringBlockingCall(
          s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
                              remoteaddr, sport=sport, seq=None)
      self.ExpectPacketOn(self.netid, desc, syn)
      msg = "SOCK_DESTROY of socket in connect, expected no RST"
      self.ExpectNoPacketsOn(self.netid, msg)

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
         android-3.4:
           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    self.IncomingConnection(5, sock_diag.TCP_SYN_RECV, self.netid)
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    states = 1 << sock_diag.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
    children = self.sock_diag.Dump(req, NO_BYTECODE)

    self.assertTrue(children)
    for child, unused_args in children:
      self.assertEqual(sock_diag.TCP_SYN_RECV, child.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
                       child.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
                       child.id.src)
596
597
# Run all the tests in this module when it is executed as a script.
if __name__ == "__main__":
  unittest.main()
600