1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import itertools
18import random
19import unittest
20
21from socket import *
22
23import iproute
24import multinetwork_base
25import net_test
26import packets
27
28
class ForwardingTest(multinetwork_base.MultiNetworkBaseTest):
  """Checks that IPv6 forwarding of packets for a dead TCP connection is safe.

  Each check puts a TCP connection into TIME_WAIT, deletes the local IPv6
  address that connection was using, and then receives packets belonging to
  the connection on an interface whose traffic is forwarded to another
  interface. The test passes if none of this crashes or hangs the kernel.
  """

  # TCP state number for TIME_WAIT, as it appears (two hex digits) in the
  # state column read back from /proc/net/tcp6.
  TCP_TIME_WAIT = 6

  def ForwardBetweenInterfaces(self, enabled, iface1, iface2):
    """Adds or removes IPv6 iif rules that forward between two networks.

    For each direction (iface1 -> iface2 and iface2 -> iface1), installs an
    input-interface rule that routes traffic arriving on one interface via
    the routing table of the other.

    Args:
      enabled: passed straight through to iproute.IifRule; presumably True
        adds the rules and False deletes them (confirm against iproute).
      iface1: netid of the first network.
      iface2: netid of the second network.
    """
    for iif, oif in itertools.permutations([iface1, iface2]):
      self.iproute.IifRule(6, enabled, self.GetInterfaceName(iif),
                           self._TableForNetid(oif), self.PRIORITY_IIF)

  def setUp(self):
    """Runs the base-class setup, then enables IPv6 forwarding globally."""
    super(ForwardingTest, self).setUp()
    self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)

  def tearDown(self):
    """Disables IPv6 forwarding, then runs the base-class teardown."""
    try:
      self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
    finally:
      # Always let the base class clean up, even if the sysctl write fails.
      super(ForwardingTest, self).tearDown()

  def _AssertSocketInTimeWait(self, myaddr, myport, remoteaddr, remoteport):
    """Asserts exactly one tcp6 socket matches this 4-tuple and is TIME_WAIT.

    Args:
      myaddr: local IPv6 address of the connection.
      myport: local port of the connection.
      remoteaddr: remote IPv6 address of the connection.
      remoteport: remote port of the connection.
    """
    mysrc = "%s:%04X" % (net_test.FormatSockStatAddress(myaddr), myport)
    mydst = "%s:%04X" % (net_test.FormatSockStatAddress(remoteaddr), remoteport)
    sockets = [s for s in self.ReadProcNetSocket("tcp6")
               if s[0] == mysrc and s[1] == mydst]
    self.assertEqual(1, len(sockets))
    self.assertEqual("%02X" % self.TCP_TIME_WAIT, sockets[0][2])

  def CheckForwardingCrash(self, netid, iface1, iface2):
    """Drives a connection to TIME_WAIT, deletes its address, then replays it.

    Args:
      netid: network on which the TCP connection is established and whose
        address is temporarily removed.
      iface1: netid of the network the stale packets are replayed on.
      iface2: netid of the network iface1 is forwarded to. Not used directly
        here; the caller sets up forwarding before calling this.
    """
    version = 6
    listenport = packets.RandomPort()
    listensocket = net_test.IPv6TCPSocket()
    # Close the listener however this check exits, including when one of the
    # handshake or TIME_WAIT assertions below fails.
    try:
      listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
      listensocket.bind(("::", listenport))
      listensocket.listen(100)
      self.SetSocketMark(listensocket, netid)

      remoteaddr = self.GetRemoteAddress(version)
      myaddr = self.MyAddress(version, netid)

      # Complete a three-way handshake initiated by the simulated remote peer.
      desc, syn = packets.SYN(listenport, version, remoteaddr, myaddr)
      synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
      msg = "Sent %s, expected %s" % (desc, synack_desc)
      reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)

      establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
      self.ReceivePacketOn(netid, establishing_ack)
      accepted, _ = listensocket.accept()
      try:
        remoteport = accepted.getpeername()[1]
      finally:
        # Closing our side first makes us send the first FIN, so our end of
        # the connection is the one that ends up in TIME_WAIT.
        accepted.close()

      desc, fin = packets.FIN(version, myaddr, remoteaddr, establishing_ack)
      self.ExpectPacketOn(netid,
                          msg + ": expecting %s after close" % desc, fin)

      desc, finack = packets.FIN(version, remoteaddr, myaddr, fin)
      self.ReceivePacketOn(netid, finack)

      # Check our socket is now in TIME_WAIT.
      self._AssertSocketInTimeWait(myaddr, listenport, remoteaddr, remoteport)

      try:
        # Remove our IP address, then replay packets of the dead connection
        # on a forwarded interface.
        self.iproute.DelAddress(myaddr, 64, self.ifindices[netid])

        self.ReceivePacketOn(iface1, finack)
        self.ReceivePacketOn(iface1, establishing_ack)
        self.ReceivePacketOn(iface1, establishing_ack)
        # No crashes? Good.
      finally:
        # Put back our IP address.
        self.SendRA(netid)
    finally:
      listensocket.close()

  def testCrash(self):
    """Repeats the forwarding check over randomly chosen network triples."""
    # The problem doesn't reliably show up on the first run, so repeat the
    # check once per permutation of self.tuns. Each iteration picks three
    # distinct netids at random, which requires at least three networks.
    for netids in itertools.permutations(self.tuns):
      # Pick an interface to send traffic on and two to forward traffic between.
      netid, iface1, iface2 = random.sample(netids, 3)
      self.ForwardBetweenInterfaces(True, iface1, iface2)
      try:
        self.CheckForwardingCrash(netid, iface1, iface2)
      finally:
        self.ForwardBetweenInterfaces(False, iface1, iface2)
106
107
if __name__ == "__main__":
  # Collect and run the TestCase classes defined in this script.
  unittest.main(module="__main__")
110