sock_diag_test.py revision eda52a427837c5840e991f41e0fcfb9b5dfc38a9
1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17from errno import *
18import os
19import random
20from socket import *
21import time
22import unittest
23
24import csocket
25import cstruct
26import multinetwork_base
27import net_test
28import packets
29import sock_diag
30import threading
31
32
# How many socketpairs _CreateLotsOfSockets creates.
NUM_SOCKETS = 100

# Bitmask of TCP states for sock_diag dumps: every bit of a 32-bit mask except
# the TIME_WAIT one.
ALL_NON_TIME_WAIT = ~(1 << sock_diag.TCP_TIME_WAIT) & 0xffffffff

# Tests that need SOCK_DESTROY are skipped on kernels older than 4.4.
# TODO: Backport SOCK_DESTROY and delete this.
HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)
39
40
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Base class providing helpers shared by the sock_diag tests."""

  @staticmethod
  def _CreateLotsOfSockets():
    """Creates NUM_SOCKETS SOCK_STREAM socketpairs over loopback.

    Each pair is created with net_test.CreateSocketPair on either 127.0.0.1
    (AF_INET) or ::1 (AF_INET6), picked at random.

    Returns:
      A dict mapping (addr, sport, dport) to the socketpair, where sport is
      the local port of the pair's first socket and dport the local port of
      its second socket.
    """
    loopbacks = [(AF_INET, "127.0.0.1"), (AF_INET6, "::1")]
    pairs = {}
    for _ in xrange(NUM_SOCKETS):
      family, addr = random.choice(loopbacks)
      pair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
      first_port = pair[0].getsockname()[1]
      second_port = pair[1].getsockname()[1]
      pairs[(addr, first_port, second_port)] = pair
    return pairs
54
55
class SockDiagTest(SockDiagBaseTest):
  """Tests sock_diag dumps, lookups and socket closing on loopback sockets."""

  def setUp(self):
    super(SockDiagTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Socketpairs created by the test; all of them are closed in tearDown.
    self.socketpairs = {}

  def tearDown(self):
    for socketpair in self.socketpairs.values():
      for s in socketpair:
        s.close()
    super(SockDiagTest, self).tearDown()

  def testFixupDiagMsg(self):
    """Checks that FixupDiagMsg only rewrites the addresses of IPv4 messages.

    An AF_INET6 message must be left unchanged. An AF_INET message must keep
    only the first 4 bytes of src and dst, with the remaining 12 bytes zeroed.
    """
    src = "0a00fa02303030312030312038302031"
    dst = "0808080841414141414141416f0a3230"
    cookie = "4078678100000000"
    sockid = sock_diag.InetDiagSockId((47436, 32069,
                                       src.decode("hex"), dst.decode("hex"), 0,
                                       cookie.decode("hex")))
    msg4 = sock_diag.InetDiagMsg((AF_INET, IPPROTO_TCP, 0,
                                  sock_diag.TCP_SYN_RECV, sockid,
                                  980, 123, 456, 789, 5555))
    # Make a copy, cstructs are mutable.
    msg6 = sock_diag.InetDiagMsg(msg4.Pack())
    msg6.family = AF_INET6

    fixed6 = sock_diag.InetDiagMsg(msg6.Pack())
    self.sock_diag.FixupDiagMsg(fixed6)
    self.assertEquals(msg6.Pack(), fixed6.Pack())

    fixed4 = sock_diag.InetDiagMsg(msg4.Pack())
    self.sock_diag.FixupDiagMsg(fixed4)
    msg4.id.src = src.decode("hex")[:4] + 12 * "\x00"
    msg4.id.dst = dst.decode("hex")[:4] + 12 * "\x00"
    self.assertEquals(msg4.Pack(), fixed4.Pack())

  def assertSocketClosed(self, sock):
    """Asserts that sock is no longer connected (getpeername gives ENOTCONN)."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock is connected."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertSockDiagMatchesSocket(self, s, diag_msg):
    """Asserts that diag_msg describes socket s.

    Compares the family, the local address and port, and the remote address
    and port. If diag_msg's destination is a wildcard address, it instead
    asserts that s is not connected.

    Note: this calls FixupDiagMsg, which modifies diag_msg in place.
    """
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    self.sock_diag.FixupDiagMsg(diag_msg)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

  def testFindsAllMySockets(self):
    """Checks that dumps and cookie lookups find every socket we created."""
    self.socketpairs = self._CreateLotsOfSockets()
    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                states=ALL_NON_TIME_WAIT)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets. The two sockets of a pair see
    # the ports swapped, so each one is recorded under its own
    # (addr, sport, dport) key, giving two cookies per pair.
    cookies = {}
    for msg, _ in sockets:
      addr = self.sock_diag.GetSourceAddress(msg)
      sport = msg.id.sport
      dport = msg.id.dport
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = msg.id.cookie

    # Did we find all the cookies?
    self.assertEquals(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump. Read the
        # cookie and state of *this* socket before asserting, so the lookup
        # below uses this socket's values.
        diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
        cookie = diag_msg.id.cookie
        state = diag_msg.state
        self.assertSockDiagMatchesSocket(sock, diag_msg)

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        req.states = 1 << state
        diag_msg = self.sock_diag.GetSockDiag(req)
        self.assertSockDiagMatchesSocket(sock, diag_msg)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testClosesSockets(self):
    """Checks that closing one side of a pair closes both sides."""
    self.socketpairs = self._CreateLotsOfSockets()
    for socketpair in self.socketpairs.values():
      # Close one of the sockets.
      # This will send a RST that will close the other side as well.
      s = random.choice(socketpair)
      if random.randrange(0, 2) == 1:
        self.sock_diag.CloseSocketFromFd(s)
      else:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)

        # Get the cookie wrong and ensure that we get an error and the socket
        # is not closed.
        real_cookie = diag_msg.id.cookie
        diag_msg.id.cookie = os.urandom(len(real_cookie))
        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
        self.assertSocketConnected(s)

        # Now close it with the correct cookie.
        req.id.cookie = real_cookie
        self.sock_diag.CloseSocket(req)

      # Check that both sockets in the pair are closed.
      self.assertSocketsClosed(socketpair)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testNonTcpSockets(self):
    """Checks that closing a UDP socket fails with EOPNOTSUPP."""
    s = socket(AF_INET6, SOCK_DGRAM, 0)
    try:
      s.connect(("::1", 53))
      # Look the socket up first; presumably this raises if it is not found.
      self.sock_diag.FindSockDiagFromFd(s)
      self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
    finally:
      s.close()

  def testNonSockDiagCommand(self):
    """Checks that an unknown netlink message type is rejected with EINVAL."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg)

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.
203
204
class SocketExceptionThread(threading.Thread):
  """Runs operation(sock) in a daemon thread and records what it raises.

  After join(), self.exception holds the exception raised by operation, or
  None if operation returned normally.
  """

  def __init__(self, sock, operation):
    """Creates the thread; call start() to run it.

    Args:
      sock: the socket to pass to operation.
      operation: a callable taking the socket as its only argument.
    """
    super(SocketExceptionThread, self).__init__()
    # Daemon thread, so an operation that never returns does not keep the
    # interpreter from exiting.
    self.daemon = True
    self.sock = sock
    self.operation = operation
    self.exception = None

  def run(self):
    try:
      self.operation(self.sock)
    except Exception as e:  # Recorded for the caller to inspect after join().
      self.exception = e
219
220
# TODO: Take a tun fd as input, make this a utility class, and reuse at least
# in forwarding_test.
class TcpTest(SockDiagBaseTest):
  """Tests closing TCP sockets whose remote peer is simulated with packets.

  The peer's packets are injected with ReceivePacketOn, and our replies are
  checked with ExpectPacketOn / ExpectNoPacketsOn on self.netid.
  """

  # Pseudo end state for IncomingConnection: complete the three-way handshake
  # but do not call accept().
  NOT_YET_ACCEPTED = -1

  def setUp(self):
    super(TcpTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # The network used by this test, chosen at random from self.tuns.
    self.netid = random.choice(list(self.tuns.keys()))

  def OpenListenSocket(self, version):
    """Returns a TCP socket listening on a port from packets.RandomPort().

    The port is stored in self.port.

    Args:
      version: 4 or 6. The socket is AF_INET bound to 0.0.0.0, or AF_INET6
        bound to ::.
    """
    self.port = packets.RandomPort()
    family = {4: AF_INET, 6: AF_INET6}[version]
    address = {4: "0.0.0.0", 6: "::"}[version]
    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    s.bind((address, self.port))
    # We haven't configured inbound iptables marking, so bind explicitly.
    self.SelectInterface(s, self.netid, "mark")
    s.listen(100)
    return s

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Calls the base class method and stores its result in last_packet."""
    pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
                                                         reply, msg)
    self.last_packet = pkt
    return pkt

  def ReceivePacketOn(self, netid, packet):
    """Calls the base class method and stores the packet in last_packet."""
    super(TcpTest, self).ReceivePacketOn(netid, packet)
    self.last_packet = packet

  def RstPacket(self):
    """Returns (desc, packet) for the RST expected in reply to last_packet."""
    return packets.RST(self.version, self.myaddr, self.remoteaddr,
                       self.last_packet)

  def IncomingConnection(self, version, end_state, netid):
    """Opens a listening socket (self.s) and drives an incoming connection.

    Args:
      version: 4, 6, or 5. 5 means IPv4 traffic to an AF_INET6 socket, with
        self.myaddr and self.remoteaddr written as "::ffff:" mapped addresses.
      end_state: how far to drive the connection:
        TCP_LISTEN: only listen.
        TCP_SYN_RECV: receive a SYN and check that a SYN+ACK is sent.
        NOT_YET_ACCEPTED: also receive the ACK completing the handshake.
        TCP_ESTABLISHED: also accept() the connection into self.accepted.
        TCP_CLOSE_WAIT: also send data, then receive a FIN and check that
          it is ACKed.
      netid: the network to inject and expect packets on.

    Raises:
      ValueError: if end_state is none of the above. This is only detected
        after the whole TCP_CLOSE_WAIT exchange has been performed.
    """
    if version == 5:
      mapped = True
      socket_version = 6
      version = 4
    else:
      socket_version = version
      mapped = False

    self.version = version
    self.s = self.OpenListenSocket(socket_version)
    self.end_state = end_state

    def MaybeMappedAddress(addr):
      return "::ffff:%s" % addr if mapped else addr

    remoteaddr = self.remoteaddr = MaybeMappedAddress(
        self.GetRemoteAddress(version))
    myaddr = self.myaddr = MaybeMappedAddress(
        self.MyAddress(version, netid))

    if end_state == sock_diag.TCP_LISTEN:
      return

    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
    if end_state == sock_diag.TCP_SYN_RECV:
      return

    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    self.ReceivePacketOn(netid, establishing_ack)

    if end_state == self.NOT_YET_ACCEPTED:
      return

    self.accepted, _ = self.s.accept()
    if end_state == sock_diag.TCP_ESTABLISHED:
      return

    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                             payload=net_test.UDP_PAYLOAD)
    self.accepted.send(net_test.UDP_PAYLOAD)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)

    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
    fin = packets._GetIpLayer(version)(str(fin))
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)

    # TODO: Why can't we use this?
    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
    self.ReceivePacketOn(netid, fin)
    time.sleep(0.1)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
    if end_state == sock_diag.TCP_CLOSE_WAIT:
      return

    raise ValueError("Invalid TCP state %d specified" % end_state)

  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    """Closes the socket and checks whether a RST is sent or not.

    Args:
      sock: the socket to close via CloseSocketFromFd, or None to use req.
      req: a diag request passed to CloseSocket, or None to use sock.
        Exactly one of sock and req must be given.
      expect_reset: if True, expect RstPacket() on self.netid; otherwise
        expect no packets at all.
      msg: prefix for assertion messages.
      do_close: if sock was given, also close() it afterwards.
    """
    if sock is not None:
      self.assertIsNone(req, "Must specify sock or req, not both")
      self.sock_diag.CloseSocketFromFd(sock)
      self.assertRaisesErrno(EINVAL, sock.accept)
    else:
      self.assertIsNotNone(req, "Must specify sock or req")
      self.sock_diag.CloseSocket(req)

    if expect_reset:
      desc, rst = self.RstPacket()
      msg = "%s: expecting %s: " % (msg, desc)
      self.ExpectPacketOn(self.netid, msg, rst)
    else:
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)

    if sock is not None and do_close:
      sock.close()

  def CheckTcpReset(self, state, statename):
    """For IPv4 and IPv6, checks RSTs when closing sockets in a given state.

    Closing the listener must not send a RST. Closing the accepted socket
    (for any state other than TCP_LISTEN) must send one.
    """
    for version in [4, 6]:
      msg = "Closing incoming IPv%d %s socket" % (version, statename)
      self.IncomingConnection(version, state, self.netid)
      self.CheckRstOnClose(self.s, None, False, msg)
      if state != sock_diag.TCP_LISTEN:
        msg = "Closing accepted IPv%d %s socket" % (version, statename)
        self.CheckRstOnClose(self.accepted, None, True, msg)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testTcpResets(self):
    """Checks that closing sockets in appropriate states sends a RST."""
    self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
    self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")

  def FindChildSockets(self, s):
    """Finds the SYN_RECV and ESTABLISHED child sockets of listening socket s.

    Returns:
      A list of diag requests, one per socket returned by a dump built from
      s's own diag message with those two states selected.
    """
    d = self.sock_diag.FindSockDiagFromFd(s)
    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    req.states = (1 << sock_diag.TCP_SYN_RECV) | (1 << sock_diag.TCP_ESTABLISHED)
    # Zero the cookie; presumably so the dump is not restricted to s itself.
    req.id.cookie = "\x00" * 8
    children = self.sock_diag.Dump(req)
    return [self.sock_diag.DiagReqFromDiagMsg(child, IPPROTO_TCP)
            for child, _ in children]

  def CheckChildSocket(self, state, statename, parent_first):
    """Checks closing a listener and its single child socket, in either order.

    Args:
      state: TCP_SYN_RECV or NOT_YET_ACCEPTED, passed to IncomingConnection.
      statename: human-readable name of state, for messages.
      parent_first: if True, close the parent listener before the child.
    """
    for version in [4, 6]:
      self.IncomingConnection(version, state, self.netid)

      d = self.sock_diag.FindSockDiagFromFd(self.s)
      parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
      children = self.FindChildSockets(self.s)
      self.assertEquals(1, len(children))

      is_established = (state == self.NOT_YET_ACCEPTED)

      # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
      # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
      # Before 4.4, we can see those sockets in dumps, but we can't fetch
      # or close them.
      can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)

      for child in children:
        if can_close_children:
          self.sock_diag.GetSockDiag(child)  # No errors? Good, child found.
        else:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseParent(expect_reset):
        msg = "Closing parent IPv%d %s socket %s child" % (
            version, statename, "before" if parent_first else "after")
        self.CheckRstOnClose(self.s, None, expect_reset, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)

      def CheckChildrenClosed():
        for child in children:
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)

      def CloseChildren():
        for child in children:
          msg = "Closing child IPv%d %s socket %s parent" % (
              version, statename, "after" if parent_first else "before")
          self.sock_diag.GetSockDiag(child)
          self.CheckRstOnClose(None, child, is_established, msg)
          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
        CheckChildrenClosed()

      if parent_first:
        # Closing the parent will close child sockets, which will send a RST,
        # iff they are already established.
        CloseParent(is_established)
        if is_established:
          CheckChildrenClosed()
        elif can_close_children:
          CloseChildren()
          CheckChildrenClosed()
        self.s.close()
      else:
        if can_close_children:
          CloseChildren()
        CloseParent(False)
        self.s.close()

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testChildSockets(self):
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Closes sock while call(sock) blocks in another thread.

    Asserts that the call finishes within a second with an IOError carrying
    expected_errno, and that sock ends up closed.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call.
    time.sleep(0.1)
    self.sock_diag.CloseSocketFromFd(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    self.assertIsNotNone(thread.exception)
    self.assertTrue(isinstance(thread.exception, IOError),
                    "Expected IOError, got %s" % thread.exception)
    self.assertEqual(expected_errno, thread.exception.errno)
    self.assertSocketClosed(sock)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testAcceptInterrupted(self):
    """Tests that accept() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
      self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
      self.assertRaisesErrno(EINVAL, self.s.accept)

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")

  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
  def testConnectInterrupted(self):
    """Tests that connect() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      # Version 5 means an AF_INET6 socket connecting to an IPv4-mapped
      # address, as in IncomingConnection; the packets on the wire are IPv4.
      mapped = (version == 5)
      ip_version = 4 if mapped else version
      family = AF_INET if ip_version == 4 and not mapped else AF_INET6
      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
      self.SelectInterface(s, self.netid, "mark")
      remoteaddr = self.GetRemoteAddress(ip_version)
      connectaddr = "::ffff:%s" % remoteaddr if mapped else remoteaddr
      s.bind(("", 0))
      _, sport = s.getsockname()[:2]
      # Bind connectaddr as a default argument so the lambda does not depend
      # on the loop variable.
      self.CloseDuringBlockingCall(
          s, lambda sock, addr=connectaddr: sock.connect((addr, 53)),
          ECONNABORTED)
      desc, syn = packets.SYN(53, ip_version,
                              self.MyAddress(ip_version, self.netid),
                              remoteaddr, sport=sport, seq=None)
      self.ExpectPacketOn(self.netid, desc, syn)
      msg = "SOCK_DESTROY of socket in connect, expected no RST"
      self.ExpectNoPacketsOn(self.netid, msg)

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
         android-3.4:
           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    self.IncomingConnection(5, sock_diag.TCP_SYN_RECV, self.netid)
    sock_id = self.sock_diag._EmptyInetDiagSockId()
    sock_id.sport = self.port
    states = 1 << sock_diag.TCP_SYN_RECV
    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
    children = self.sock_diag.Dump(req)

    self.assertTrue(children)
    for child, _ in children:
      self.assertEqual(sock_diag.TCP_SYN_RECV, child.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
                       child.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
                       child.id.src)
501
502
if __name__ == "__main__":
  # Run the tests in this module when it is executed as a script.
  unittest.main(module="__main__")
505