1# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import threading
6import select
7import socket
8import subprocess
9import sys
10import unittest
11
12from lansim import host
13from lansim import simulator
14from lansim import tuntap
15
16
def raise_exception():
    """Unconditionally raises a plain Exception.

    The tests schedule this as a callback whose firing must be detected.
    """
    message = 'Something bad.'
    raise Exception(message)
20
21
class InfoTCPServer(threading.Thread):
    """A TCP server running on a separate thread.

    This simple TCP server thread listens for connections. For every new
    connection it sends back the ASCII string
    "<client_addr> <client_port> <server_port>" and closes the connection.
    """
    def __init__(self, host, port, poll_interval=1.):
        """Creates the TCP server on the host:port address.

        @param host: The IP address in plain text.
        @param port: The TCP port number where the server listens on. If 0,
                the kernel picks a free port and that port is the one
                reported to clients.
        @param poll_interval: Seconds between checks of the stop flag. This
                bounds how long join() can block after stop() is called.
        """
        threading.Thread.__init__(self)
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self._sock.bind((host, port))
            self._sock.listen(1)
        except Exception:
            # Don't leak the listening socket if it cannot be set up.
            self._sock.close()
            raise
        # The bound port equals |port| unless |port| was 0.
        self._port = self._sock.getsockname()[1]
        self._poll_interval = poll_interval
        self._must_exit = False


    def run(self):
        """Serves connections until stop() is called, then closes the socket."""
        try:
            while not self._must_exit:
                # Wake up periodically to re-check the must_exit flag.
                readable, _, _ = select.select(
                        [self._sock], [], [], self._poll_interval)
                if self._sock in readable:
                    self._serve_one()
        finally:
            self._sock.close()


    def _serve_one(self):
        """Accepts one pending connection and sends it the info line."""
        conn, (addr, port) = self._sock.accept()
        try:
            # Send back the client address, port and our port. sendall()
            # keeps writing until every byte is sent; send() may stop short.
            # encode() yields a str on Python 2 and bytes on Python 3.
            info = '%s %d %d' % (addr, port, self._port)
            conn.sendall(info.encode('ascii'))
        finally:
            conn.close()


    def stop(self):
        """Signal the termination of the running thread."""
        self._must_exit = True
57
58
def GetInfoTCP(host, port):
    """Connects to an InfoTCPServer on host:port and reads its reply.

    @param host: The host where the InfoTCPServer is running.
    @param port: The port where the InfoTCPServer is running.
    @return: The data returned by a single recv() of up to 1024 bytes.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.connect((host, port))
        data = sock.recv(1024)
    finally:
        # Close the socket even when connect() or recv() raises.
        sock.close()
    return data
70
71
class SimulatorTest(unittest.TestCase):
    """Unit tests for the Simulator class."""

    def setUp(self):
        """Creates a Simulator under test over a TAP device."""
        self._tap = tuntap.TunTap(tuntap.IFF_TAP, name="faketap")
        # According to RFC 3927 (Dynamic Configuration of IPv4 Link-Local
        # Addresses), a host can pseudorandomly assign itself an IPv4 address
        # on the 169.254/16 network to communicate with other devices on the
        # same link in the absence of a DHCP server or other source of network
        # configuration. The tests in this class explicitly specify the
        # interface to use, so they can run in parallel even when more than
        # one interface has the same IPv4 address. A TUN/TAP interface with an
        # IPv4 address in this range shouldn't collide with any useful service
        # running on a different (physical) interface.
        self._tap.set_addr('169.254.11.11')
        self._tap.up()

        self._sim = simulator.Simulator(self._tap)


    def tearDown(self):
        """Brings the TAP interface down."""
        self._tap.down()


    def testTimeout(self):
        """Tests that the Simulator can start and run for a short time."""
        # Run for at most 100ms and finish the test. This implies that the
        # stop() method works.
        self._sim.run(timeout=0.1)


    def testRemoveTimeout(self):
        """Tests that the Simulator can remove unfired timeout calls."""
        # Schedule the callback far in the future, run the simulator for a
        # short time and then remove it. A second removal must report that
        # nothing was removed.
        self._sim.add_timeout(60, raise_exception)
        self._sim.run(timeout=0.1)
        self.assertTrue(self._sim.remove_timeout(raise_exception))
        self.assertFalse(self._sim.remove_timeout(raise_exception))


    def testUntil(self):
        """Tests that the Simulator can run until a condition is met."""
        tasks_done = []
        # After 0.2 seconds a task is added to tasks_done, which makes the
        # |until| condition true and should break the loop. If it doesn't, a
        # second value is added after 4 seconds, making the test fail.
        self._sim.add_timeout(0.2, lambda: tasks_done.append('good task'))
        self._sim.add_timeout(4.0, lambda: tasks_done.append('bad task'))
        self._sim.run(timeout=5.0, until=lambda: tasks_done)
        self.assertEqual(len(tasks_done), 1)


    def testHost(self):
        """Tests that the Simulator can add rules from the SimpleHost."""
        # The simulated IP and MAC addresses are unknown to the rest of the
        # system as they only live on this interface. Again, any IP on the
        # 169.254/16 network should not cause any problem with other services
        # running on this host.
        host.SimpleHost(self._sim, '12:34:56:78:90:AB', '169.254.11.22')
        self._sim.run(timeout=0.1)
135
136
class SimulatorThreadTest(unittest.TestCase):
    """Unit tests for the SimulatorThread class."""

    def setUp(self):
        """Creates a SimulatorThread under test over a TAP device."""
        self._tap = tuntap.TunTap(tuntap.IFF_TAP, name="faketap")
        # See note about IP addresses on SimulatorTest.setUp().
        self._ip_addr = '169.254.11.11'
        self._tap.set_addr(self._ip_addr)
        self._tap.up()

        # 20 seconds timeout for unittest completion (they should run in about
        # 2 seconds each).
        self._sim = simulator.SimulatorThread(self._tap, timeout=20)


    def tearDown(self):
        """Stops the simulator thread and brings the TAP interface down.

        If the simulator thread recorded an exception in its error member,
        it is reported on stderr and re-raised so the test fails.
        """
        try:
            self._sim.stop() # stop() is idempotent.
            self._sim.join()
        finally:
            # Release the interface even if stopping the thread fails, so it
            # is not left up after the test.
            self._tap.down()
        if self._sim.error:
            sys.stderr.write(
                    'SimulatorThread exception: %r\n' % self._sim.error)
            sys.stderr.write(self._sim.traceback)
            raise self._sim.error


    def testError(self):
        """Exceptions raised on the thread appear on the error member."""
        self._sim.add_timeout(0.1, raise_exception)
        self._sim.start()
        self._sim.join()
        # str() works on both Python 2 and 3; BaseException.message does not
        # exist on Python 3.
        self.assertEqual(str(self._sim.error), 'Something bad.')
        # Clean the error before tearDown()
        self._sim.error = None


    def testARPPing(self):
        """Test that the simulator properly handles a ARP request/response."""
        host.SimpleHost(self._sim, '12:34:56:78:90:22', '169.254.11.22')
        # Two hosts share 169.254.11.33 on purpose, so arping for that address
        # must get two replies.
        host.SimpleHost(self._sim, '12:34:56:78:90:33', '169.254.11.33')
        host.SimpleHost(self._sim, '12:34:56:78:90:44', '169.254.11.33')

        self._sim.start()
        # arping and wait for one second for the responses.
        out = subprocess.check_output(
                ['arping', '-I', self._tap.name, '169.254.11.22',
                 '-c', '1', '-w', '1'])
        resp = [line for line in out.splitlines() if 'Unicast reply' in line]
        self.assertEqual(len(resp), 1)
        self.assertTrue(resp[0].startswith(
                'Unicast reply from 169.254.11.22 [12:34:56:78:90:22]'))

        out = subprocess.check_output(
                ['arping', '-I', self._tap.name, '169.254.11.33',
                 '-c', '1', '-w', '1'])
        resp = [line for line in out.splitlines() if 'Unicast reply' in line]
        self.assertEqual(len(resp), 2)
        resp.sort()
        self.assertTrue(resp[0].startswith(
                'Unicast reply from 169.254.11.33 [12:34:56:78:90:33]'))
        self.assertTrue(resp[1].startswith(
                'Unicast reply from 169.254.11.33 [12:34:56:78:90:44]'))


    def testTCPForward(self):
        """Host can forward TCP traffic back to the kernel network stack."""
        h = host.SimpleHost(self._sim, '12:34:56:78:90:22', '169.254.11.22')
        # Launch two TCP servers on the network interface end. Each server is
        # guarded by a finally clause as soon as it is started. The servers
        # are not daemon threads, so a leaked one would keep the test process
        # alive.
        srv1 = InfoTCPServer(self._ip_addr, 1080)
        srv1.start()
        try:
            srv2 = InfoTCPServer(self._ip_addr, 1081)
            srv2.start()
            try:
                # Map those two ports to a given IP address on the fake network.
                h.tcp_forward(80, self._ip_addr, 1080)
                h.tcp_forward(81, self._ip_addr, 1081)

                # Start the simulation.
                self._sim.start()

                srv1data = GetInfoTCP('169.254.11.22', 80)
                srv2data = GetInfoTCP('169.254.11.22', 81)
            finally:
                srv2.stop()
                srv2.join()
        finally:
            srv1.stop()
            srv1.join()

        # First connection is seen from the .11.22:1024 client.
        self.assertEqual(srv1data, '169.254.11.22 1024 1080')
        # The second connection is also seen from the .11.22:1024 client
        # because it is made to a different port.
        self.assertEqual(srv2data, '169.254.11.22 1024 1081')


    def testWaitForCondition(self):
        """Main thread can wait until a condition is met on the simulator."""
        self._sim.start()

        # Wait for an always False condition.
        condition = lambda: False
        ret = self._sim.wait_for_condition(condition, timeout=1.5)
        self.assertFalse(ret)

        # Wait for a trivially True condition.
        condition = lambda: True
        ret = self._sim.wait_for_condition(condition, timeout=10.)
        self.assertTrue(ret)

        # Without timeout.
        ret = self._sim.wait_for_condition(condition, timeout=None)
        self.assertTrue(ret)

        # Wait for a condition that returns None (falsy) on its first three
        # calls, appending to |var| each time. After that it returns |var|,
        # which then holds three elements.
        var = []
        condition = lambda: var if len(var) == 3 else var.append(None)
        ret = self._sim.wait_for_condition(condition, timeout=10.)
        self.assertEqual(len(ret), 3)
257
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main(module='__main__')
260