1#!/usr/bin/python
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import errno
18import os
19import random
20from socket import *  # pylint: disable=wildcard-import
21import struct
22import time           # pylint: disable=unused-import
23import unittest
24
25from scapy import all as scapy
26
27import iproute
28import multinetwork_base
29import net_test
30
# Identifier, payload, sequence number and ToS of the ICMP echo requests
# sent by the ping tests.
PING_IDENT = 0xff19
PING_PAYLOAD = "foobarbaz"
PING_SEQ = 3
PING_TOS = 0x83

# Ancillary-data type used to pass an IPv6 flow label as a cmsg (see
# CheckPktinfoRouting). Defined locally, presumably because the socket module
# does not export it.
IPV6_FLOWINFO = 11


# Payload of the outgoing UDP test packets: a serialized DNS AAAA query whose
# transaction ID is picked at random once, at import time.
UDP_PAYLOAD = str(
    scapy.DNS(
        rd=1,
        id=random.randint(0, 65535),
        qd=scapy.DNSQR(qname="wWW.GoOGle.CoM", qtype="AAAA")
    )
)


# Sysctls toggled by the mark-reflection and TCP accept tests.
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"

# Optional kernel features, detected by whether their sysctl files exist.
HAVE_MARK_REFLECT = os.path.isfile(IPV4_MARK_REFLECT_SYSCTL)
HAVE_TCP_MARK_ACCEPT = os.path.isfile(TCP_MARK_ACCEPT_SYSCTL)

# IP_UNICAST_IF / IPV6_UNICAST_IF were added somewhere between Linux 3.1 and
# 3.4, so they are only relied upon from 3.4 onwards.
HAVE_UNICAST_IF = (3, 4, 0) <= net_test.LINUX_VERSION
55
56
class ConfigurationError(AssertionError):
  """Raised when configuring the test environment (e.g. iptables) fails."""
59
60
class Packets(object):
  """Builders for the scapy packets these tests send and expect.

  Every builder except RandomPort returns a (description, packet) tuple, where
  description is a human-readable string used in failure messages.
  """

  # TCP header flag bits.
  TCP_FIN = 1
  TCP_SYN = 2
  TCP_RST = 4
  TCP_PSH = 8
  TCP_ACK = 16

  # Default sequence number and window size for generated TCP packets.
  TCP_SEQ = 1692871236
  TCP_WINDOW = 14400

  @staticmethod
  def RandomPort():
    """Returns a random port number in [1025, 65535]."""
    return random.randint(1025, 65535)

  @staticmethod
  def _GetIpLayer(version):
    """Returns the scapy IP layer class (IP or IPv6) for an IP version."""
    return {4: scapy.IP, 6: scapy.IPv6}[version]

  @staticmethod
  def _SetPacketTos(packet, tos):
    """Sets the IPv4 ToS or IPv6 traffic class of packet in place.

    Raises:
      ValueError: if packet is neither a scapy IP nor IPv6 packet.
    """
    if isinstance(packet, scapy.IPv6):
      packet.tc = tos
    elif isinstance(packet, scapy.IP):
      packet.tos = tos
    else:
      raise ValueError("Can't find ToS Field")

  @classmethod
  def UDP(cls, version, srcaddr, dstaddr, sport=0):
    """Builds a UDP packet to port 53 carrying UDP_PAYLOAD.

    sport=0 picks a random source port; sport=None leaves it unspecified.
    """
    ip = cls._GetIpLayer(version)
    # Can't just use "if sport" because None has meaning (it means unspecified).
    if sport == 0:
      sport = cls.RandomPort()
    return ("UDPv%d packet" % version,
            ip(src=srcaddr, dst=dstaddr) /
            scapy.UDP(sport=sport, dport=53) / UDP_PAYLOAD)

  @classmethod
  def UDPWithOptions(cls, version, srcaddr, dstaddr, sport=0):
    """Builds a UDP packet with TTL/hop limit 39, ToS/tclass 0x83 and, for
    IPv6, flow label 0xbeef (the values CheckPktinfoRouting sets)."""
    if version == 4:
      packet = (scapy.IP(src=srcaddr, dst=dstaddr, ttl=39, tos=0x83) /
                scapy.UDP(sport=sport, dport=53) /
                UDP_PAYLOAD)
    else:
      packet = (scapy.IPv6(src=srcaddr, dst=dstaddr,
                           fl=0xbeef, hlim=39, tc=0x83) /
                scapy.UDP(sport=sport, dport=53) /
                UDP_PAYLOAD)
    return ("UDPv%d packet with options" % version, packet)

  @classmethod
  def SYN(cls, dport, version, srcaddr, dstaddr, sport=0, seq=TCP_SEQ):
    """Builds a TCP SYN to dport. sport=0 picks a random source port."""
    ip = cls._GetIpLayer(version)
    if sport == 0:
      sport = cls.RandomPort()
    return ("TCP SYN",
            ip(src=srcaddr, dst=dstaddr) /
            scapy.TCP(sport=sport, dport=dport,
                      seq=seq, ack=0,
                      flags=cls.TCP_SYN, window=cls.TCP_WINDOW))

  @classmethod
  def RST(cls, version, srcaddr, dstaddr, packet):
    """Builds the RST+ACK answering the TCP segment in packet."""
    ip = cls._GetIpLayer(version)
    original = packet.getlayer("TCP")
    return ("TCP RST",
            ip(src=srcaddr, dst=dstaddr) /
            scapy.TCP(sport=original.dport, dport=original.sport,
                      ack=original.seq + 1, seq=None,
                      flags=cls.TCP_RST | cls.TCP_ACK, window=cls.TCP_WINDOW))

  @classmethod
  def SYNACK(cls, version, srcaddr, dstaddr, packet):
    """Builds the SYN+ACK answering the SYN in packet (seq/window unset)."""
    ip = cls._GetIpLayer(version)
    original = packet.getlayer("TCP")
    return ("TCP SYN+ACK",
            ip(src=srcaddr, dst=dstaddr) /
            scapy.TCP(sport=original.dport, dport=original.sport,
                      ack=original.seq + 1, seq=None,
                      flags=cls.TCP_SYN | cls.TCP_ACK, window=None))

  @classmethod
  def ACK(cls, version, srcaddr, dstaddr, packet, payload=""):
    """Builds the ACK (or PSH+ACK data segment if payload) answering packet.

    The acknowledgement number covers the original's payload plus one for a
    SYN or FIN flag.
    """
    ip = cls._GetIpLayer(version)
    original = packet.getlayer("TCP")
    was_syn_or_fin = (original.flags & (cls.TCP_SYN | cls.TCP_FIN)) != 0
    ack_delta = was_syn_or_fin + len(original.payload)
    desc = "TCP data" if payload else "TCP ACK"
    flags = cls.TCP_ACK | cls.TCP_PSH if payload else cls.TCP_ACK
    return (desc,
            ip(src=srcaddr, dst=dstaddr) /
            scapy.TCP(sport=original.dport, dport=original.sport,
                      ack=original.seq + ack_delta, seq=original.ack,
                      flags=flags, window=cls.TCP_WINDOW) /
            payload)

  @classmethod
  def FIN(cls, version, srcaddr, dstaddr, packet):
    """Builds a FIN+ACK answering packet (acking its FIN, if it had one)."""
    ip = cls._GetIpLayer(version)
    original = packet.getlayer("TCP")
    was_fin = (original.flags & cls.TCP_FIN) != 0
    return ("TCP FIN",
            ip(src=srcaddr, dst=dstaddr) /
            scapy.TCP(sport=original.dport, dport=original.sport,
                      ack=original.seq + was_fin, seq=original.ack,
                      flags=cls.TCP_ACK | cls.TCP_FIN, window=cls.TCP_WINDOW))

  @classmethod
  def GRE(cls, version, srcaddr, dstaddr, proto, packet):
    """Wraps packet in a GRE header with protocol proto over IP/IPv6."""
    if version == 4:
      ip = scapy.IP(src=srcaddr, dst=dstaddr, proto=net_test.IPPROTO_GRE)
    else:
      ip = scapy.IPv6(src=srcaddr, dst=dstaddr, nh=net_test.IPPROTO_GRE)
    packet = ip / scapy.GRE(proto=proto) / packet
    return ("GRE packet", packet)

  @classmethod
  def ICMPPortUnreachable(cls, version, srcaddr, dstaddr, packet):
    """Builds the ICMP/ICMPv6 port unreachable error quoting packet."""
    if version == 4:
      # Linux hardcodes the ToS on ICMP errors to 0xc0 or greater because of
      # RFC 1812 4.3.2.5 (!).
      return ("ICMPv4 port unreachable",
              scapy.IP(src=srcaddr, dst=dstaddr, proto=1, tos=0xc0) /
              scapy.ICMPerror(type=3, code=3) / packet)
    else:
      return ("ICMPv6 port unreachable",
              scapy.IPv6(src=srcaddr, dst=dstaddr) /
              scapy.ICMPv6DestUnreach(code=4) / packet)

  @classmethod
  def ICMPPacketTooBig(cls, version, srcaddr, dstaddr, packet):
    """Builds an ICMP "fragmentation needed" / ICMPv6 "packet too big" error.

    The IPv4 error carries 1280 in its "unused" (next-hop MTU) field and
    quotes the first 64 bytes of packet. For IPv6, the quoted copy of packet
    has its UDP payload cut to 1280-40-8 bytes and is then truncated to 1232
    bytes. The caller's packet is not modified.
    """
    if version == 4:
      return ("ICMPv4 fragmentation needed",
              scapy.IP(src=srcaddr, dst=dstaddr, proto=1) /
              scapy.ICMPerror(type=3, code=4, unused=1280) / str(packet)[:64])
    else:
      # Work on a copy: truncating the UDP payload must not leak back into
      # the packet the caller passed in.
      packet = packet.copy()
      udp = packet.getlayer("UDP")
      # 1280 bytes minus the 40-byte IPv6 and 8-byte UDP headers.
      udp.payload = str(udp.payload)[:1280-40-8]
      return ("ICMPv6 Packet Too Big",
              scapy.IPv6(src=srcaddr, dst=dstaddr) /
              scapy.ICMPv6PacketTooBig() / str(packet)[:1232])

  @classmethod
  def ICMPEcho(cls, version, srcaddr, dstaddr):
    """Builds the echo request the ping tests send, with ToS PING_TOS."""
    ip = cls._GetIpLayer(version)
    icmp = {4: scapy.ICMP, 6: scapy.ICMPv6EchoRequest}[version]
    packet = (ip(src=srcaddr, dst=dstaddr) /
              icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
    cls._SetPacketTos(packet, PING_TOS)
    return ("ICMPv%d echo" % version, packet)

  @classmethod
  def ICMPReply(cls, version, srcaddr, dstaddr, packet):
    """Builds the echo reply the kernel is expected to send to ICMPEcho."""
    ip = cls._GetIpLayer(version)
    # Scapy doesn't provide an ICMP echo reply constructor.
    icmpv4_reply = lambda **kwargs: scapy.ICMP(type=0, **kwargs)
    icmp = {4: icmpv4_reply, 6: scapy.ICMPv6EchoReply}[version]
    packet = (ip(src=srcaddr, dst=dstaddr) /
              icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
    # IPv6 only started copying the tclass to echo replies in 3.14.
    if version == 4 or net_test.LINUX_VERSION >= (3, 14):
      cls._SetPacketTos(packet, PING_TOS)
    return ("ICMPv%d echo reply" % version, packet)

  @classmethod
  def NS(cls, srcaddr, tgtaddr, srcmac):
    """Builds a neighbour solicitation for tgtaddr, sent to the
    solicited-node multicast address ff02::1:ffXX:XXXX built from the last
    three bytes of tgtaddr."""
    packed = inet_pton(AF_INET6, tgtaddr)
    # bytearray yields ints from both Python 2 str and Python 3 bytes.
    last3bytes = tuple(bytearray(packed[-3:]))
    solicited = "ff02::1:ff%02x:%02x%02x" % last3bytes
    packet = (scapy.IPv6(src=srcaddr, dst=solicited) /
              scapy.ICMPv6ND_NS(tgt=tgtaddr) /
              scapy.ICMPv6NDOptSrcLLAddr(lladdr=srcmac))
    return ("ICMPv6 NS", packet)

  @classmethod
  def NA(cls, srcaddr, dstaddr, srcmac):
    """Builds a solicited, override neighbour advertisement for srcaddr."""
    packet = (scapy.IPv6(src=srcaddr, dst=dstaddr) /
              scapy.ICMPv6ND_NA(tgt=srcaddr, R=0, S=1, O=1) /
              scapy.ICMPv6NDOptDstLLAddr(lladdr=srcmac))
    return ("ICMPv6 NA", packet)
242
243
class InboundMarkingTest(multinetwork_base.MultiNetworkBaseTest):
  """Base class whose tests need incoming packets marked with their netid.

  For every test network, setUpClass appends an iptables and an ip6tables
  rule to the mangle table's INPUT chain that sets the mark of packets
  arriving on that network's interface to the netid. tearDownClass deletes
  the rules again.
  """

  @classmethod
  def _SetInboundMarking(cls, netid, is_add):
    """Adds or deletes the inbound-marking rules for one network.

    Args:
      netid: the network whose interface's incoming packets are marked; it is
        also the mark value.
      is_add: True to append (-A) the rules, False to delete (-D) them.

    Raises:
      ConfigurationError: if /sbin/iptables or /sbin/ip6tables exits non-zero.
    """
    # Neither value depends on the IP version, so compute them once.
    iface = cls.GetInterfaceName(netid)
    add_del = "-A" if is_add else "-D"
    for iptables in ("iptables", "ip6tables"):
      argv = [iptables, add_del, "INPUT", "-t", "mangle", "-i", iface,
              "-j", "MARK", "--set-mark", "%d" % netid]
      ret = os.spawnvp(os.P_WAIT, "/sbin/" + iptables, argv)
      if ret:
        raise ConfigurationError("Setup command failed: %s" % " ".join(argv))

  @classmethod
  def setUpClass(cls):
    super(InboundMarkingTest, cls).setUpClass()
    for netid in cls.tuns:
      cls._SetInboundMarking(netid, True)

  @classmethod
  def tearDownClass(cls):
    try:
      for netid in cls.tuns:
        cls._SetInboundMarking(netid, False)
    finally:
      # Always run the base-class teardown, even if deleting a rule failed.
      super(InboundMarkingTest, cls).tearDownClass()

  @classmethod
  def SetMarkReflectSysctls(cls, value):
    """Sets the IPv4 and (if present) IPv6 fwmark_reflect sysctls to value."""
    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    try:
      cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
    except IOError:
      # This does not exist if we use the version of the patch that uses a
      # common sysctl for IPv4 and IPv6.
      pass
281
282
class OutgoingTest(multinetwork_base.MultiNetworkBaseTest):
  """Checks that outgoing packets leave on the network the socket selects.

  Sockets are steered to a network by mark, UID, oif, IP[V6]_UNICAST_IF,
  a sticky IPV6_PKTINFO option or per-packet pktinfo. The resulting packets
  are expected on that network's interface. Every socket created here is
  closed explicitly, including when an expectation fails.
  """

  # How many times to run outgoing packet tests.
  ITERATIONS = 5

  def CheckPingPacket(self, version, netid, routing_mode, dstaddr, packet):
    """Sends a ping routed to netid and expects it on netid's interface.

    Args:
      version: IP version, 4 or 6.
      netid: the network the socket is routed to.
      routing_mode: how BuildSocket selects the network (e.g. "mark", "oif").
      dstaddr: destination address of the ping.
      packet: the ICMP header to send (presumably self.IPV4_PING or
        self.IPV6_PING). PING_PAYLOAD is appended to it.
    """
    s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
    try:
      myaddr = self.MyAddress(version, netid)
      s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
      s.bind((myaddr, PING_IDENT))
      net_test.SetSocketTos(s, PING_TOS)

      desc, expected = Packets.ICMPEcho(version, myaddr, dstaddr)
      msg = "IPv%d ping: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))

      s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))

      self.ExpectPacketOn(netid, msg, expected)
    finally:
      s.close()

  def CheckTCPSYNPacket(self, version, netid, routing_mode, dstaddr):
    """Starts a TCP connect routed to netid and expects the SYN there.

    A v4-mapped dstaddr on an IPv6 socket is expected to produce an IPv4 SYN.
    """
    s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
    try:
      if version == 6 and dstaddr.startswith("::ffff"):
        version = 4
      myaddr = self.MyAddress(version, netid)
      desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
                                   sport=None, seq=None)

      # Non-blocking TCP connects always return EINPROGRESS.
      self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
      msg = "IPv%s TCP connect: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))
      self.ExpectPacketOn(netid, msg, expected)
    finally:
      s.close()

  def CheckUDPPacket(self, version, netid, routing_mode, dstaddr):
    """Sends UDP routed to netid via sendto() and, except in ucast_oif mode,
    via connect()+send(), expecting each packet on netid's interface."""
    s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
    try:
      if version == 6 and dstaddr.startswith("::ffff"):
        version = 4
      myaddr = self.MyAddress(version, netid)
      desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
      msg = "IPv%s UDP %%s: expected %s on %s" % (
          version, desc, self.GetInterfaceName(netid))

      s.sendto(UDP_PAYLOAD, (dstaddr, 53))
      self.ExpectPacketOn(netid, msg % "sendto", expected)

      # IP_UNICAST_IF doesn't seem to work on connected sockets, so only
      # test connect()+send() in the other routing modes.
      if routing_mode != "ucast_oif":
        s.connect((dstaddr, 53))
        s.send(UDP_PAYLOAD)
        self.ExpectPacketOn(netid, msg % "connect/send", expected)
    finally:
      # Close on every path, including ucast_oif mode and failures.
      s.close()

  def CheckRawGrePacket(self, version, netid, routing_mode, dstaddr):
    """Sends a GRE packet carrying UDP of the other IP version on a raw
    socket routed to netid and expects it on netid's interface."""
    s = self.BuildSocket(version, net_test.RawGRESocket, netid, routing_mode)
    try:
      inner_version = {4: 6, 6: 4}[version]
      inner_src = self.MyAddress(inner_version, netid)
      inner_dst = self.GetRemoteAddress(inner_version)
      inner = str(Packets.UDP(inner_version, inner_src, inner_dst,
                              sport=None)[1])

      ethertype = {4: net_test.ETH_P_IP, 6: net_test.ETH_P_IPV6}[inner_version]
      # A minimal GRE header: 16 zero bits of flags/version followed by the
      # 16-bit ethertype of the encapsulated packet.
      packet = struct.pack("!HH", 0, ethertype) + inner
      myaddr = self.MyAddress(version, netid)

      s.sendto(packet, (dstaddr, IPPROTO_GRE))
      desc, expected = Packets.GRE(version, myaddr, dstaddr, ethertype, inner)
      msg = "Raw IPv%d GRE with inner IPv%d UDP: expected %s on %s" % (
          version, inner_version, desc, self.GetInterfaceName(netid))
      self.ExpectPacketOn(netid, msg, expected)
    finally:
      s.close()

  def CheckOutgoingPackets(self, routing_mode):
    """Runs the ping/TCP/UDP/GRE checks on every network, ITERATIONS times."""
    v4addr = self.IPV4_ADDR
    v6addr = self.IPV6_ADDR
    v4mapped = "::ffff:" + v4addr

    for _ in xrange(self.ITERATIONS):
      for netid in self.tuns:

        self.CheckPingPacket(4, netid, routing_mode, v4addr, self.IPV4_PING)
        # Skipped in oif mode because of a kernel bug.
        if routing_mode != "oif":
          self.CheckPingPacket(6, netid, routing_mode, v6addr, self.IPV6_PING)

        # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
        if routing_mode != "ucast_oif":
          self.CheckTCPSYNPacket(4, netid, routing_mode, v4addr)
          self.CheckTCPSYNPacket(6, netid, routing_mode, v6addr)
          self.CheckTCPSYNPacket(6, netid, routing_mode, v4mapped)

        self.CheckUDPPacket(4, netid, routing_mode, v4addr)
        self.CheckUDPPacket(6, netid, routing_mode, v6addr)
        self.CheckUDPPacket(6, netid, routing_mode, v4mapped)

        # Creating raw sockets on non-root UIDs requires properly setting
        # capabilities, which is hard to do from Python.
        # IP_UNICAST_IF is not supported on raw sockets.
        if routing_mode not in ["uid", "ucast_oif"]:
          self.CheckRawGrePacket(4, netid, routing_mode, v4addr)
          self.CheckRawGrePacket(6, netid, routing_mode, v6addr)

  def testMarkRouting(self):
    """Checks that socket marking selects the right outgoing interface."""
    self.CheckOutgoingPackets("mark")

  @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
  def testUidRouting(self):
    """Checks that UID routing selects the right outgoing interface."""
    self.CheckOutgoingPackets("uid")

  def testOifRouting(self):
    """Checks that oif routing selects the right outgoing interface."""
    self.CheckOutgoingPackets("oif")

  @unittest.skipUnless(HAVE_UNICAST_IF, "no support for UNICAST_IF")
  def testUcastOifRouting(self):
    """Checks that ucast oif routing selects the right outgoing interface."""
    self.CheckOutgoingPackets("ucast_oif")

  def CheckRemarking(self, version, use_connect):
    """Checks that re-selecting the network of a live UDP socket reroutes it.

    For each applicable routing mode, one socket is switched to every network
    in turn, and a packet sent after each switch must appear on that network.

    Args:
      version: IP version, 4 or 6.
      use_connect: if True, the socket is connected on the first network
        before being switched, and the expected source address stays fixed.
    """
    # Remarking or resetting UNICAST_IF on connected sockets does not work.
    if use_connect:
      modes = ["oif"]
    else:
      modes = ["mark", "oif"]
      if HAVE_UNICAST_IF:
        modes += ["ucast_oif"]

    unspec = {4: "0.0.0.0", 6: "::"}[version]
    dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
    connected_str = "Connected" if use_connect else "Unconnected"

    for mode in modes:
      s = net_test.UDPSocket(self.GetProtocolFamily(version))
      try:
        # Figure out what packets to expect.
        sport = Packets.RandomPort()
        s.bind((unspec, sport))
        desc, expected = Packets.UDP(version, unspec, dstaddr, sport)

        # If we're testing connected sockets, connect the socket on the first
        # netid now.
        if use_connect:
          netid = next(iter(self.tuns))
          self.SelectInterface(s, netid, mode)
          s.connect((dstaddr, 53))
          expected.src = self.MyAddress(version, netid)

        # For each netid, select that network without closing the socket, and
        # check that the packets sent on that socket go out on the right
        # network.
        for netid in self.tuns:
          self.SelectInterface(s, netid, mode)
          if not use_connect:
            expected.src = self.MyAddress(version, netid)
          s.sendto(UDP_PAYLOAD, (dstaddr, 53))
          msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % (
              connected_str, version, mode, desc, self.GetInterfaceName(netid))
          self.ExpectPacketOn(netid, msg, expected)
          self.SelectInterface(s, None, mode)
      finally:
        s.close()

  def testIPv4Remarking(self):
    """Checks that updating the mark on an IPv4 socket changes routing."""
    self.CheckRemarking(4, False)
    self.CheckRemarking(4, True)

  def testIPv6Remarking(self):
    """Checks that updating the mark on an IPv6 socket changes routing."""
    self.CheckRemarking(6, False)
    self.CheckRemarking(6, True)

  def testIPv6StickyPktinfo(self):
    """Checks that a sticky IPV6_PKTINFO option selects the interface."""
    for _ in xrange(self.ITERATIONS):
      for netid in self.tuns:
        s = net_test.UDPSocket(AF_INET6)
        try:
          # Set a flowlabel.
          net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xdead)
          s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_FLOWINFO_SEND, 1)

          # Set some destination options.
          nonce = "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c"
          dstopts = "".join([
              "\x11\x02",              # Next header=UDP, length 2 (24 bytes).
              "\x01\x06", "\x00" * 6,  # PadN, 6 bytes of padding.
              "\x8b\x0c",              # ILNP nonce, 12 bytes.
              nonce
          ])
          s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, dstopts)
          s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_HOPS, 255)

          pktinfo = multinetwork_base.MakePktInfo(6, None,
                                                  self.ifindices[netid])

          # Set the sticky pktinfo option.
          s.setsockopt(net_test.SOL_IPV6, IPV6_PKTINFO, pktinfo)

          # Specify the flowlabel in the destination address.
          s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 53, 0xdead, 0))

          sport = s.getsockname()[1]
          srcaddr = self.MyAddress(6, netid)
          expected = (scapy.IPv6(src=srcaddr, dst=net_test.IPV6_ADDR,
                                 fl=0xdead, hlim=255) /
                      scapy.IPv6ExtHdrDestOpt(
                          options=[scapy.PadN(optdata="\x00\x00\x00\x00\x00\x00"),
                                   scapy.HBHOptUnknown(otype=0x8b,
                                                       optdata=nonce)]) /
                      scapy.UDP(sport=sport, dport=53) /
                      UDP_PAYLOAD)
          msg = "IPv6 UDP using sticky pktinfo: expected UDP packet on %s" % (
              self.GetInterfaceName(netid))
          self.ExpectPacketOn(netid, msg, expected)
        finally:
          s.close()

  def CheckPktinfoRouting(self, version):
    """Checks that per-packet routing via SendOnNetid selects the interface
    and that the TTL/ToS (and IPv6 flow label) settings are honoured."""
    for _ in xrange(self.ITERATIONS):
      for netid in self.tuns:
        family = self.GetProtocolFamily(version)
        s = net_test.UDPSocket(family)
        try:
          if version == 6:
            # Create a flowlabel so we can use it.
            net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xbeef)

            # Specify some arbitrary options.
            cmsgs = [
                (net_test.SOL_IPV6, IPV6_HOPLIMIT, 39),
                (net_test.SOL_IPV6, IPV6_TCLASS, 0x83),
                (net_test.SOL_IPV6, IPV6_FLOWINFO, int(htonl(0xbeef))),
            ]
          else:
            # Support for setting IPv4 TOS and TTL via cmsg only appeared in
            # 3.13.
            cmsgs = []
            s.setsockopt(net_test.SOL_IP, IP_TTL, 39)
            s.setsockopt(net_test.SOL_IP, IP_TOS, 0x83)

          dstaddr = self.GetRemoteAddress(version)
          self.SendOnNetid(version, s, dstaddr, 53, netid, UDP_PAYLOAD, cmsgs)

          sport = s.getsockname()[1]
          srcaddr = self.MyAddress(version, netid)

          desc, expected = Packets.UDPWithOptions(version, srcaddr, dstaddr,
                                                  sport=sport)

          msg = "IPv%d UDP using pktinfo routing: expected %s on %s" % (
              version, desc, self.GetInterfaceName(netid))
          self.ExpectPacketOn(netid, msg, expected)
        finally:
          s.close()

  def testIPv4PktinfoRouting(self):
    self.CheckPktinfoRouting(4)

  def testIPv6PktinfoRouting(self):
    self.CheckPktinfoRouting(6)
539
540
class MarkTest(InboundMarkingTest):
  """Checks that kernel-generated replies follow the mark of the request."""

  def CheckReflection(self, version, gen_packet, gen_reply):
    """Checks that replies go out on the same interface as the original.

    For each combination returned by self.Combinations(version), a request
    built by gen_packet is injected and the kernel's response is compared
    with the packet built by gen_reply. This is done twice: once with the
    fwmark_reflect sysctls set to 0 and once with them set to 1. With
    reflection off, reply=None is passed to _ReceiveAndExpectResponse
    (presumably meaning "expect no reply"), except for IPv6 pings.

    Args:
      version: An integer, 4 or 6.
      gen_packet: callable(version, srcaddr, dstaddr) returning a
        (description, scapy packet) tuple.
      gen_reply: callable(version, srcaddr, dstaddr, packet) returning a
        (description, scapy packet) tuple.
    """
    for netid, iif, ip_if, local_addr, peer_addr in self.Combinations(version):
      # The request comes from the peer and is addressed to us.
      desc, request = gen_packet(version, peer_addr, local_addr)

      # HACK: IPv6 ping replies always do a routing lookup with the interface
      # the ping came in on, so they are reflected correctly even when mark
      # reflection is not working. Don't fail when that happens.
      always_reflected = desc == "ICMPv6 echo"

      for reflect in (0, 1):
        self.SetMarkReflectSysctls(reflect)
        if not (reflect or always_reflected):
          reply_desc = reply = None
        else:
          reply_desc, reply = gen_reply(version, local_addr, peer_addr,
                                        request)

        extra = "reflect=%d" % reflect
        msg = self._FormatMessage(iif, ip_if, extra, desc, reply_desc)
        self._ReceiveAndExpectResponse(netid, request, reply, msg)

  def SYNToClosedPort(self, *args):
    """Returns Packets.SYN to port 999, which is expected to be closed."""
    closed_port = 999
    return Packets.SYN(closed_port, *args)

  # UDP to a closed port -> ICMP port unreachable.
  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
  def testIPv4ICMPErrorsReflectMark(self):
    self.CheckReflection(4, Packets.UDP, Packets.ICMPPortUnreachable)

  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
  def testIPv6ICMPErrorsReflectMark(self):
    self.CheckReflection(6, Packets.UDP, Packets.ICMPPortUnreachable)

  # Echo request -> echo reply (also checks the ToS is copied).
  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
  def testIPv4PingRepliesReflectMarkAndTos(self):
    self.CheckReflection(4, Packets.ICMPEcho, Packets.ICMPReply)

  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
  def testIPv6PingRepliesReflectMarkAndTos(self):
    self.CheckReflection(6, Packets.ICMPEcho, Packets.ICMPReply)

  # SYN to a closed port -> RST.
  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
  def testIPv4RSTsReflectMark(self):
    self.CheckReflection(4, self.SYNToClosedPort, Packets.RST)

  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
  def testIPv6RSTsReflectMark(self):
    self.CheckReflection(6, self.SYNToClosedPort, Packets.RST)
607
608
class TCPAcceptTest(InboundMarkingTest):
  """Checks that accepted TCP connections are routed on the right network."""

  MODE_BINDTODEVICE = "SO_BINDTODEVICE"
  MODE_INCOMING_MARK = "incoming mark"
  MODE_EXPLICIT_MARK = "explicit mark"
  MODE_UID = "uid"

  @classmethod
  def setUpClass(cls):
    super(TCPAcceptTest, cls).setUpClass()

    # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
    # will accept both IPv4 and IPv6 connections. We do this here instead of in
    # each test so we can use the same socket every time. That way, if a kernel
    # bug causes incoming packets to mark the listening socket instead of the
    # accepted socket, the test will fail as soon as the next address/interface
    # combination is tried.
    cls.listenport = 1234
    cls.listensocket = net_test.IPv6TCPSocket()
    cls.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    cls.listensocket.bind(("::", cls.listenport))
    cls.listensocket.listen(100)

  @classmethod
  def tearDownClass(cls):
    """Closes the shared listening socket, then tears down the base class."""
    try:
      cls.listensocket.close()
    finally:
      super(TCPAcceptTest, cls).tearDownClass()

  def BounceSocket(self, s):
    """Attempts to invalidate a socket's destination cache entry.

    For IPv4 sockets, IP_TOS is set to 53 and then restored. For IPv6 sockets,
    an 8-byte IPV6_DSTOPTS header is set and then cleared.
    """
    if s.family == AF_INET:
      tos = s.getsockopt(SOL_IP, IP_TOS)
      s.setsockopt(net_test.SOL_IP, IP_TOS, 53)
      s.setsockopt(net_test.SOL_IP, IP_TOS, tos)
    else:
      # Next header=UDP, length 0 (8 bytes); PadN with 4 bytes of padding.
      pad8 = "".join(["\x11\x00", "\x01\x04", "\x00" * 4])
      s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, pad8)
      s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, "")

  def _SetTCPMarkAcceptSysctl(self, value):
    self.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)

  def CheckTCPConnection(self, mode, listensocket, netid, version,
                         myaddr, remoteaddr, packet, reply, msg):
    """Completes a handshake, exchanges data, closes, and checks routing.

    Args:
      mode: one of the MODE_* constants, or None.
      listensocket: the listening socket the SYN was delivered to.
      netid: the network the connection arrives on.
      version: IP version, 4 or 6.
      myaddr: our address on netid.
      remoteaddr: the peer's address.
      packet: the original SYN (unused here; kept for interface stability).
      reply: the SYN+ACK the kernel sent.
      msg: prefix for failure messages.
    """
    establishing_ack = Packets.ACK(version, remoteaddr, myaddr, reply)[1]

    # Attempt to confuse the kernel.
    self.BounceSocket(listensocket)

    self.ReceivePacketOn(netid, establishing_ack)

    # If we're using UID routing, the accept() call has to be run as a UID that
    # is routed to the specified netid, because the UID of the socket returned
    # by accept() is the effective UID of the process that calls it. It doesn't
    # need to be the same UID; any UID that selects the same interface will do.
    with net_test.RunAsUid(self.UidForNetid(netid)):
      s, _ = listensocket.accept()

    try:
      # Check that data sent on the connection goes out on the right interface.
      desc, data = Packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                               payload=UDP_PAYLOAD)
      s.send(UDP_PAYLOAD)
      self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
      self.BounceSocket(s)

      # Keep up our end of the conversation.
      ack = Packets.ACK(version, remoteaddr, myaddr, data)[1]
      self.BounceSocket(listensocket)
      self.ReceivePacketOn(netid, ack)

      mark = self.GetSocketMark(s)
    finally:
      self.BounceSocket(s)
      s.close()

    if mode == self.MODE_INCOMING_MARK:
      self.assertEquals(netid, mark,
                        msg + ": Accepted socket: Expected mark %d, got %d" % (
                            netid, mark))
    elif mode != self.MODE_EXPLICIT_MARK:
      self.assertEquals(0, self.GetSocketMark(listensocket))

    # Check the FIN was sent on the right interface, and ack it. We don't expect
    # this to fail because by the time the connection is established things are
    # likely working, but a) extra tests are always good and b) extra packets
    # like the FIN (and retransmitted FINs) could cause later tests that expect
    # no packets to fail.
    desc, fin = Packets.FIN(version, myaddr, remoteaddr, ack)
    self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)

    _, finack = Packets.FIN(version, remoteaddr, myaddr, fin)
    self.ReceivePacketOn(netid, finack)

    # Since we called close() earlier, the userspace socket object is gone, so
    # the socket has no UID. If we're doing UID routing, the ack might be routed
    # incorrectly. Not much we can do here.
    _, finackack = Packets.ACK(version, myaddr, remoteaddr, finack)
    if mode != self.MODE_UID:
      self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)
    else:
      self.ClearTunQueues()

  def _CheckTCPAccept(self, version, mode, syncookies, netid, iif, ip_if,
                      myaddr, remoteaddr):
    """Tests one incoming connection for a mode/network/address combination."""
    # UID mode uses a per-network listening socket built in UID routing mode;
    # every other mode reuses the shared class-level socket.
    own_socket = mode == self.MODE_UID
    if own_socket:
      listensocket = self.BuildSocket(6, net_test.TCPSocket, netid, mode)
    else:
      listensocket = self.listensocket

    try:
      if own_socket:
        listensocket.listen(100)

      listenport = listensocket.getsockname()[1]

      if HAVE_TCP_MARK_ACCEPT:
        accept_sysctl = 1 if mode == self.MODE_INCOMING_MARK else 0
        self._SetTCPMarkAcceptSysctl(accept_sysctl)

      bound_dev = iif if mode == self.MODE_BINDTODEVICE else None
      self.BindToDevice(listensocket, bound_dev)

      mark = netid if mode == self.MODE_EXPLICIT_MARK else 0
      self.SetSocketMark(listensocket, mark)

      # Generate the packet here instead of once per mode, so subsequent TCP
      # connections use different source ports and retransmissions from old
      # connections don't confuse subsequent tests.
      desc, packet = Packets.SYN(listenport, version, remoteaddr, myaddr)

      if mode:
        reply_desc, reply = Packets.SYNACK(version, myaddr, remoteaddr,
                                           packet)
      else:
        reply_desc, reply = None, None

      extra = "mode=%s, syncookies=%d" % (mode, syncookies)
      msg = self._FormatMessage(iif, ip_if, extra, desc, reply_desc)
      reply = self._ReceiveAndExpectResponse(netid, packet, reply, msg)
      if reply:
        self.CheckTCPConnection(mode, listensocket, netid, version, myaddr,
                                remoteaddr, packet, reply, msg)
    finally:
      # Only the per-network UID socket is ours to close here; the shared
      # socket is closed in tearDownClass.
      if own_socket:
        listensocket.close()

  def CheckTCP(self, version, modes):
    """Checks that incoming TCP connections work.

    Every mode is run with tcp_syncookies set to 0 and then to 2, so that
    both the regular and the SYN cookie accept paths are exercised. The
    original sysctl value is restored afterwards. If the sysctl does not
    exist, the loop runs without changing it.

    Args:
      version: An integer, 4 or 6.
      modes: A list of modes to exercise.
    """
    have_syncookies = os.path.isfile(SYNCOOKIES_SYSCTL)
    if have_syncookies:
      with open(SYNCOOKIES_SYSCTL) as f:
        saved_syncookies = int(f.read())

    try:
      for syncookies in [0, 2]:
        # Previously the loop variable was only reported in messages and the
        # sysctl was never changed.
        if have_syncookies:
          self.SetSysctl(SYNCOOKIES_SYSCTL, syncookies)
        for mode in modes:
          for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(
              version):
            self._CheckTCPAccept(version, mode, syncookies, netid, iif, ip_if,
                                 myaddr, remoteaddr)
    finally:
      if have_syncookies:
        self.SetSysctl(SYNCOOKIES_SYSCTL, saved_syncookies)

  def testBasicTCP(self):
    self.CheckTCP(4, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
    self.CheckTCP(6, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])

  @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
  def testIPv4MarkAccept(self):
    self.CheckTCP(4, [self.MODE_INCOMING_MARK])

  @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
  def testIPv6MarkAccept(self):
    self.CheckTCP(6, [self.MODE_INCOMING_MARK])

  @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
  def testIPv4UidAccept(self):
    self.CheckTCP(4, [self.MODE_UID])

  @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
  def testIPv6UidAccept(self):
    self.CheckTCP(6, [self.MODE_UID])

  def testIPv6ExplicitMark(self):
    self.CheckTCP(6, [self.MODE_EXPLICIT_MARK])
777
778
class RATest(multinetwork_base.MultiNetworkBaseTest):
  """Tests for router advertisement handling and RA-learned routes."""

  def testDoesNotHaveObsoleteSysctl(self):
    """Checks that the obsolete autoconf_table_offset sysctl is absent."""
    self.assertFalse(os.path.isfile(
        "/proc/sys/net/ipv6/route/autoconf_table_offset"))

  @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
                       "no support for per-table autoconf")
  def testPurgeDefaultRouters(self):
    """Checks that enabling IPv6 forwarding makes IPv6 sends fail.

    While forwarding is enabled, sends on every network are expected to fail
    with ENETUNREACH. Afterwards, forwarding is disabled, an RA is re-sent on
    every network, and connectivity is checked again.
    """

    def CheckIPv6Connectivity(expect_connectivity):
      for netid in self.NETIDS:
        s = net_test.UDPSocket(AF_INET6)
        try:
          self.SetSocketMark(s, netid)
          if expect_connectivity:
            self.assertTrue(s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 1234)))
          else:
            self.assertRaisesErrno(errno.ENETUNREACH, s.sendto, UDP_PAYLOAD,
                                   (net_test.IPV6_ADDR, 1234))
        finally:
          # One probe socket per netid per check; don't leak descriptors.
          s.close()

    try:
      CheckIPv6Connectivity(True)
      self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
      CheckIPv6Connectivity(False)
    finally:
      self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
      for netid in self.NETIDS:
        self.SendRA(netid)
      CheckIPv6Connectivity(True)

  def testOnlinkCommunication(self):
    """Checks that on-link communication goes direct and not through routers."""
    for netid in self.tuns:
      # Send a UDP packet to a random on-link destination.
      s = net_test.UDPSocket(AF_INET6)
      try:
        iface = self.GetInterfaceName(netid)
        self.BindToDevice(s, iface)
        # dstaddr can never be our address because GetRandomDestination only
        # fills in the lower 32 bits, but our address has 0xff in the byte
        # before that (since it's constructed from the EUI-64 and so has ff:fe
        # in the middle).
        dstaddr = self.GetRandomDestination(self.IPv6Prefix(netid))
        s.sendto(UDP_PAYLOAD, (dstaddr, 53))

        # Expect an NS for that destination on the interface.
        myaddr = self.MyAddress(6, netid)
        mymac = self.MyMacAddress(netid)
        desc, expected = Packets.NS(myaddr, dstaddr, mymac)
        msg = "Sending UDP packet to on-link destination: expecting %s" % desc
        time.sleep(0.0001)  # Required to make the test work on kernel 3.1(!)
        self.ExpectPacketOn(netid, msg, expected)

        # Send an NA.
        tgtmac = "02:00:00:00:%02x:99" % netid
        _, reply = Packets.NA(dstaddr, myaddr, tgtmac)
        # Don't use ReceivePacketOn, since that uses the router's MAC address
        # as the source. Instead, construct our own Ethernet header with source
        # MAC of tgtmac.
        reply = scapy.Ether(src=tgtmac, dst=mymac) / reply
        self.ReceiveEtherPacketOn(netid, reply)

        # Expect the kernel to send the original UDP packet now that the ND
        # cache entry has been populated.
        sport = s.getsockname()[1]
        desc, expected = Packets.UDP(6, myaddr, dstaddr, sport=sport)
        msg = "After NA response, expecting %s" % desc
        self.ExpectPacketOn(netid, msg, expected)
      finally:
        # Close only after the queued UDP packet has been checked, so the
        # socket stays open for the whole exchange, as it did originally.
        s.close()

  # This test documents a known issue: routing tables are never deleted.
  @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
                       "no support for per-table autoconf")
  def testLeftoverRoutes(self):
    """Checks that routes remain after temporary interfaces are closed.

    Compares the line count of /proc/net/ipv6_route before and after ten
    temporary tun interfaces (netids 10-19) each receive an RA and are closed.
    """
    def GetNumRoutes():
      with open("/proc/net/ipv6_route") as route_file:
        return len(route_file.readlines())

    num_routes = GetNumRoutes()
    for i in range(10, 20):
      # Register the tun outside the try. If CreateTunInterface fails, there
      # is nothing to clean up, and a KeyError from the cleanup would hide the
      # real error.
      self.tuns[i] = self.CreateTunInterface(i)
      try:
        self.SendRA(i)
      finally:
        # Always unregister and close the tun, even if SendRA fails.
        self.tuns.pop(i).close()
    self.assertLess(num_routes, GetNumRoutes())
862
863
class PMTUTest(InboundMarkingTest):
  """Tests Path MTU discovery across socket-selection modes."""

  PAYLOAD_SIZE = 1400

  # Socket options to change PMTU behaviour.
  IP_MTU_DISCOVER = 10
  IP_PMTUDISC_DO = 1
  IPV6_DONTFRAG = 62

  # Socket options to get the MTU.
  IP_MTU = 14
  IPV6_PATHMTU = 61

  def GetSocketMTU(self, version, s):
    """Returns the path MTU the kernel reports for socket s.

    Args:
      version: 4 or 6.
      s: the socket to query (connected, in the callers below).
    """
    if version == 6:
      # IPV6_PATHMTU yields a 28-byte sockaddr followed by a 32-bit MTU.
      ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, self.IPV6_PATHMTU, 32)
      unused_sockaddr, mtu = struct.unpack("=28sI", ip6_mtuinfo)
      return mtu
    else:
      return s.getsockopt(net_test.SOL_IP, self.IP_MTU)

  def DisableFragmentationAndReportErrors(self, version, s):
    """Forbids fragmentation on s and enables IP[V6]_RECVERR error reporting."""
    if version == 4:
      s.setsockopt(net_test.SOL_IP, self.IP_MTU_DISCOVER, self.IP_PMTUDISC_DO)
      s.setsockopt(net_test.SOL_IP, net_test.IP_RECVERR, 1)
    else:
      s.setsockopt(net_test.SOL_IPV6, self.IPV6_DONTFRAG, 1)
      s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)

  def CheckPMTU(self, version, use_connect, modes):
    """Checks that an ICMP "packet too big" lowers the path MTU to 1280.

    For every network and mode, this sends an oversized UDP packet and injects
    a "packet too big" reply. It then checks three things:
      - further sends fail with EMSGSIZE;
      - a second socket to the same destination sees an MTU of 1280;
      - the route reported for that destination has an MTU of 1280.

    Args:
      version: 4 or 6.
      use_connect: if True, connect the socket and use send(); otherwise send
        each packet with SendOnNetid.
      modes: socket-selection modes passed to BuildSocket.
    """

    def SendBigPacket(version, s, dstaddr, netid, payload):
      if use_connect:
        s.send(payload)
      else:
        self.SendOnNetid(version, s, dstaddr, 1234, netid, payload, [])

    # The destination prefix and the address of the intermediate router that
    # sends the ICMP error depend only on the IP version.
    dst_prefix, intermediate = {
        4: ("172.19.", "172.16.9.12"),
        6: ("2001:db8::", "2001:db8::1")
    }[version]
    payload = self.PAYLOAD_SIZE * "a"

    for netid in self.tuns:
      for mode in modes:
        srcaddr = self.MyAddress(version, netid)
        dstaddr = self.GetRandomDestination(dst_prefix)

        s = self.BuildSocket(version, net_test.UDPSocket, netid, mode)
        try:
          self.DisableFragmentationAndReportErrors(version, s)

          if use_connect:
            s.connect((dstaddr, 1234))

          # Send a packet and receive a packet too big.
          SendBigPacket(version, s, dstaddr, netid, payload)
          packets = self.ReadAllPacketsOn(netid)
          self.assertEquals(1, len(packets))
          _, toobig = Packets.ICMPPacketTooBig(version, intermediate, srcaddr,
                                               packets[0])
          self.ReceivePacketOn(netid, toobig)

          # Check that another send on the same socket returns EMSGSIZE.
          self.assertRaisesErrno(
              errno.EMSGSIZE,
              SendBigPacket, version, s, dstaddr, netid, payload)

          # If this is a connected socket, make sure the socket MTU was set.
          # Note that in IPv4 this only started working in Linux 3.6!
          if use_connect and (version == 6 or
                              net_test.LINUX_VERSION >= (3, 6)):
            self.assertEquals(1280, self.GetSocketMTU(version, s))
        finally:
          s.close()

        # Check that other sockets pick up the PMTU we have been told about by
        # connecting another socket to the same destination and getting its
        # MTU. This new socket can use any method to select its outgoing
        # interface; here we use a mark for simplicity.
        s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
        try:
          s2.connect((dstaddr, 1234))
          self.assertEquals(1280, self.GetSocketMTU(version, s2))

          # Also check the MTU reported by ip route get, this time using the
          # oif. s2 stays open here, matching the original ordering.
          routes = self.iproute.GetRoutes(dstaddr, self.ifindices[netid], 0,
                                          None)
          self.assertTrue(routes)
          rtmsg, attributes = routes[0]
          self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
          metrics = attributes["RTA_METRICS"]
          self.assertEquals(metrics["RTAX_MTU"], 1280)
        finally:
          s2.close()

  def testIPv4BasicPMTU(self):
    self.CheckPMTU(4, True, ["mark", "oif"])
    self.CheckPMTU(4, False, ["mark", "oif"])

  def testIPv6BasicPMTU(self):
    self.CheckPMTU(6, True, ["mark", "oif"])
    self.CheckPMTU(6, False, ["mark", "oif"])

  @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
  def testIPv4UIDPMTU(self):
    self.CheckPMTU(4, True, ["uid"])
    self.CheckPMTU(4, False, ["uid"])

  @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
  def testIPv6UIDPMTU(self):
    self.CheckPMTU(6, True, ["uid"])
    self.CheckPMTU(6, False, ["uid"])

  # Making Path MTU Discovery work on unmarked sockets requires that mark
  # reflection be enabled. Otherwise the kernel has no way to know what routing
  # table the original packet used, and thus it won't be able to clone the
  # correct route.

  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
  def testIPv4UnmarkedSocketPMTU(self):
    self.SetMarkReflectSysctls(1)
    try:
      self.CheckPMTU(4, False, [None])
    finally:
      self.SetMarkReflectSysctls(0)

  @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
  def testIPv6UnmarkedSocketPMTU(self):
    self.SetMarkReflectSysctls(1)
    try:
      self.CheckPMTU(6, False, [None])
    finally:
      self.SetMarkReflectSysctls(0)
993
994
@unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest):
  """Tests UID range routing rules and UID-based route lookups."""

  def GetRulesAtPriority(self, version, priority):
    """Returns the (rule, attributes) pairs of all rules at priority.

    Rules without an FRA_PRIORITY attribute are treated as priority 0.
    """
    rules = self.iproute.DumpRules(version)
    return [(rule, attributes) for rule, attributes in rules
            if attributes.get("FRA_PRIORITY", 0) == priority]

  def CheckInitialTablesHaveNoUIDs(self, version):
    """Asserts that no rule at priority 0, 32766 or 32767 has a UID range."""
    # Presumably these are the priorities of the kernel's default rules.
    rules = []
    for priority in [0, 32766, 32767]:
      rules.extend(self.GetRulesAtPriority(version, priority))
    for _, attributes in rules:
      self.assertNotIn("FRA_UID_START", attributes)
      self.assertNotIn("FRA_UID_END", attributes)

  def testIPv4InitialTablesHaveNoUIDs(self):
    self.CheckInitialTablesHaveNoUIDs(4)

  def testIPv6InitialTablesHaveNoUIDs(self):
    self.CheckInitialTablesHaveNoUIDs(6)

  def CheckGetAndSetRules(self, version):
    """Adds a random UID range rule, checks it reads back, then deletes it."""
    def Random():
      return random.randint(1000000, 2000000)

    start, end = sorted([Random(), Random()])
    table = Random()
    priority = Random()

    # Add the rule outside the try. If adding fails, there is nothing to
    # delete, and a failing delete in the finally would hide the real error.
    self.iproute.UidRangeRule(version, True, start, end, table,
                              priority=priority)
    try:
      rules = self.GetRulesAtPriority(version, priority)
      self.assertTrue(rules)
      _, attributes = rules[-1]
      self.assertEquals(priority, attributes["FRA_PRIORITY"])
      self.assertEquals(start, attributes["FRA_UID_START"])
      self.assertEquals(end, attributes["FRA_UID_END"])
      self.assertEquals(table, attributes["FRA_TABLE"])
    finally:
      self.iproute.UidRangeRule(version, False, start, end, table,
                                priority=priority)

  def testIPv4GetAndSetRules(self):
    self.CheckGetAndSetRules(4)

  def testIPv6GetAndSetRules(self):
    self.CheckGetAndSetRules(6)

  def ExpectNoRoute(self, addr, oif, mark, uid):
    """Asserts that no usable route to addr exists for oif/mark/uid.

    The lack of a route may be either an error (ENETUNREACH) or an
    unreachable route. Any other error is re-raised.
    """
    try:
      routes = self.iproute.GetRoutes(addr, oif, mark, uid)
    except IOError as e:
      # The error is compared against the negated errno value.
      if int(e.errno) != -int(errno.ENETUNREACH):
        raise
    else:
      rtmsg, _ = routes[0]
      self.assertEquals(iproute.RTN_UNREACHABLE, rtmsg.type)

  def ExpectRoute(self, addr, oif, mark, uid):
    """Asserts that the first route to addr for oif/mark/uid is unicast."""
    routes = self.iproute.GetRoutes(addr, oif, mark, uid)
    rtmsg, _ = routes[0]
    self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)

  def CheckGetRoute(self, version, addr):
    """Checks that addr is routable only when a per-netid UID is supplied.

    version is accepted for symmetry with the other Check* helpers but is not
    used.
    """
    self.ExpectNoRoute(addr, 0, 0, 0)
    for netid in self.NETIDS:
      uid = self.UidForNetid(netid)
      self.ExpectRoute(addr, 0, 0, uid)
    self.ExpectNoRoute(addr, 0, 0, 0)

  def testIPv4RouteGet(self):
    self.CheckGetRoute(4, net_test.IPV4_ADDR)

  def testIPv6RouteGet(self):
    self.CheckGetRoute(6, net_test.IPV6_ADDR)
1074
1075
class RulesTest(net_test.NetworkTest):
  """Tests for adding and deleting policy routing rules."""

  RULE_PRIORITY = 99999

  def setUp(self):
    self.iproute = iproute.IPRoute()
    # Start from a clean slate in case a previous run left rules behind.
    for ip_version in (4, 6):
      self.iproute.DeleteRulesAtPriority(ip_version, self.RULE_PRIORITY)

  def tearDown(self):
    for ip_version in (4, 6):
      self.iproute.DeleteRulesAtPriority(ip_version, self.RULE_PRIORITY)

  def testRuleDeletionMatchesTable(self):
    """Checks that deleting a rule only removes the rule for that table.

    This checks for a kernel bug where deletion requests for tables > 256
    ignored the table.
    """
    mark = 300
    kept_table, deleted_table = 301, 302
    for ip_version in (4, 6):
      # Add two rules with the same mark, one per table.
      for table in (kept_table, deleted_table):
        self.iproute.FwmarkRule(ip_version, True, mark, table,
                                priority=self.RULE_PRIORITY)
      # Delete only the rule pointing at deleted_table.
      self.iproute.FwmarkRule(ip_version, False, mark, deleted_table,
                              priority=self.RULE_PRIORITY)
      # The rule pointing at kept_table must still be around.
      remaining = [attrs for _, attrs in self.iproute.DumpRules(ip_version)
                   if attrs.get("FRA_PRIORITY", 0) == self.RULE_PRIORITY]
      self.assertEquals(1, len(remaining))
      self.assertEquals(kept_table, remaining[0]["FRA_TABLE"])
1106
1107
if __name__ == "__main__":
  # Run the test cases defined in this module when executed as a script.
  unittest.main()
1110