1#!/usr/bin/env python
2# Copyright 2011 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""System integration test for traffic shaping.
17
18Usage:
19$ sudo ./trafficshaper_test.py
20"""
21
22import daemonserver
23import logging
24import platformsettings
25import socket
26import SocketServer
27import trafficshaper
28import unittest
29
# Request prefix asking TimedTcpHandler for a reply of a given size.
RESPONSE_SIZE_KEY = 'response-size:'
# Ports the UDP (DNS-like) and TCP (HTTP-like) test servers listen on.
TEST_DNS_PORT, TEST_HTTP_PORT = 5555, 8888
# Clock shared by the test servers (default |timer|) and the test cases.
TIMER = platformsettings.timer
34
35
def GetElapsedMs(start_time, end_time):
  """Compute the whole milliseconds from |start_time| to |end_time|.

  Args:
    start_time: timestamp in seconds, as a float or a string holding one.
    end_time: timestamp in seconds, as a float or a string holding one.
  Returns:
    The difference in milliseconds as an int, truncated toward zero.
  """
  elapsed_seconds = float(end_time) - float(start_time)
  return int(elapsed_seconds * 1000)
46
47
class TrafficShaperTest(unittest.TestCase):
  """Argument-validation tests for trafficshaper.TrafficShaper."""

  def testBadBandwidthRaises(self):
    """'1KBit/s' (not the 'Kbit/s' form used elsewhere) must be rejected."""
    bad_bandwidth = '1KBit/s'
    self.assertRaises(
        trafficshaper.BandwidthValueError,
        trafficshaper.TrafficShaper, down_bandwidth=bad_bandwidth)
54
55
class TimedUdpHandler(SocketServer.DatagramRequestHandler):
  """Datagram handler that replies with the server timer value at read time."""

  def handle(self):
    self.rfile.read()  # Consume the request; its contents are not used.
    self.wfile.write(str(self.server.timer()))
63
64
class TimedTcpHandler(SocketServer.StreamRequestHandler):
  """TCP handler replying with the server timestamp taken after reading.

  The whole request is read before the timestamp is taken. By default the
  reply is just str(timestamp). A request of the form
    RESPONSE_SIZE_KEY <num_response_bytes> <newline> <any data>
  asks for the timestamp, a newline, and enough NUL bytes to make the reply
  num_response_bytes long (no padding if the timestamp alone is too long).
  """

  def handle(self):
    request = self.rfile.read()
    reply = str(self.server.timer())
    if request.startswith(RESPONSE_SIZE_KEY):
      newline_pos = request.index('\n')
      wanted_size = int(request[len(RESPONSE_SIZE_KEY):newline_pos])
      pad_len = wanted_size - len(reply) - 1
      reply = reply + '\n' + '\x00' * pad_len
    self.wfile.write(reply)
82
83
class TimedUdpServer(SocketServer.ThreadingUDPServer,
                     daemonserver.DaemonServer):
  """A simple UDP server similar to dnsproxy.

  Each datagram is answered by TimedUdpHandler with the value of |timer|
  observed when the request was handled.
  """

  # Override the SocketServer.UDPServer default (False) so the test port can
  # be re-bound right away, avoiding intermittent bind errors between tests.
  allow_reuse_address = True

  def __init__(self, host, port, timer=TIMER):
    """Bind to (host, port).

    Args:
      host: address to listen on.
      port: UDP port to listen on.
      timer: zero-argument clock callable used by TimedUdpHandler; defaults
          to TIMER.
    """
    SocketServer.ThreadingUDPServer.__init__(
        self, (host, port), TimedUdpHandler)
    self.timer = timer

  def cleanup(self):
    """Cleanup hook; does nothing for this server."""
    # NOTE(review): unlike TimedTcpServer.cleanup this does not call
    # shutdown(); confirm daemonserver.DaemonServer stops the serve loop
    # and closes the socket on its own.
    pass
98
99
class TimedTcpServer(SocketServer.ThreadingTCPServer,
                     daemonserver.DaemonServer):
  """A simple TCP server similar to httpproxy.

  Requests are answered by TimedTcpHandler, which timestamps them with
  |timer|.
  """

  # Override the SocketServer.TCPServer default (False) so the test port can
  # be re-bound right away, avoiding intermittent bind errors between tests.
  allow_reuse_address = True

  def __init__(self, host, port, timer=TIMER):
    """Bind to (host, port).

    Args:
      host: address to listen on.
      port: TCP port to listen on.
      timer: zero-argument clock callable used by TimedTcpHandler; defaults
          to TIMER.
    """
    SocketServer.ThreadingTCPServer.__init__(
        self, (host, port), TimedTcpHandler)
    self.timer = timer

  def cleanup(self):
    """Stop the serve loop with shutdown(), ignoring KeyboardInterrupt."""
    try:
      self.shutdown()
    except KeyboardInterrupt:
      # Best-effort teardown: swallow a Ctrl-C raised while shutdown() waits
      # for the serve loop to exit.
      pass
117
118
class TcpTestSocketCreator(object):
  """Context manager that opens a TCP connection on entry, closes it on exit.

  One instance may be reused for several sequential with-statements; every
  entry opens a fresh connection and stores it in |self.socket|.
  """

  def __init__(self, host, port, timeout=1.0):
    self.address, self.timeout = (host, port), timeout

  def __enter__(self):
    conn = socket.create_connection(self.address, timeout=self.timeout)
    self.socket = conn
    return conn

  def __exit__(self, *exc_info):
    # Returns None, so any exception raised in the with-body propagates.
    self.socket.close()
132
133
class TimedTestCase(unittest.TestCase):
  """TestCase with a relative-tolerance comparison for timing measurements."""

  def assertValuesAlmostEqual(self, expected, actual, tolerance=0.05):
    """Fail unless |actual| is within |tolerance| * |expected| of |expected|.

    Equivalent to the following, with a clearer failure message:
      delta = abs(tolerance * expected)
      assertTrue(expected - delta <= actual <= expected + delta)

    Args:
      expected: the expected value.
      actual: the measured value.
      tolerance: allowed deviation as a fraction of |expected|
          (0.05 means +/- 5%).
    """
    # abs() keeps the interval well-formed for a negative |expected|; a
    # negative delta would invert the bounds and reject every value.
    delta = abs(tolerance * expected)
    if actual > expected + delta or actual < expected - delta:
      self.fail('%s is not equal to expected %s +/- %s%%' % (
          actual, expected, 100 * tolerance))
144
145
class TcpTrafficShaperTest(TimedTestCase):
  """Checks TrafficShaper delay and bandwidth limits on TCP connections.

  Each test starts a TimedTcpServer on TEST_HTTP_PORT and compares measured
  timings against the configured shaping. Tests log a warning and return
  early (and so pass) when ipfw is unavailable.
  """

  def setUp(self):
    self.host = platformsettings.get_server_ip_address()
    self.port = TEST_HTTP_PORT
    self.tcp_socket_creator = TcpTestSocketCreator(self.host, self.port)
    self.timer = TIMER

  def TrafficShaper(self, **kwargs):
    """Return a trafficshaper.TrafficShaper bound to this test's host/port.

    Args:
      **kwargs: extra TrafficShaper arguments (e.g. delay_ms, up_bandwidth).
    """
    return trafficshaper.TrafficShaper(
        host=self.host, ports=(self.port,), **kwargs)

  def GetTcpSendTimeMs(self, num_bytes):
    """Return time in milliseconds to send |num_bytes| to the server.

    Measured from just before sending until the timestamp TimedTcpHandler
    takes after it has read the whole request.
    """
    with self.tcp_socket_creator as s:
      start_time = self.timer()
      request_data = '\x00' * num_bytes
      s.sendall(request_data)
      # TimedTcpHandler reads the request with rfile.read(), which returns
      # only at EOF, so half-close the write side to let it respond.
      s.shutdown(socket.SHUT_WR)
      read_time = s.recv(1024)
    return GetElapsedMs(start_time, read_time)

  def GetTcpReceiveTimeMs(self, num_bytes):
    """Return time in milliseconds to receive a |num_bytes| response.

    The server replies with its read timestamp, a newline, and NUL padding
    up to |num_bytes| in total. The elapsed time runs from that timestamp
    until the whole response has been received.
    """
    with self.tcp_socket_creator as s:
      s.sendall('%s%s\n' % (RESPONSE_SIZE_KEY, num_bytes))
      # The handler reads with rfile.read() until EOF, so half-close the
      # write side to let it respond.
      s.shutdown(socket.SHUT_WR)
      num_remaining_bytes = num_bytes
      header = ''
      read_time = None
      while num_remaining_bytes > 0:
        response_data = s.recv(4096)
        if not response_data:
          # recv() keeps returning '' once the peer has closed; fail rather
          # than spin forever.
          self.fail('Connection closed with %d of %d bytes not received' % (
              num_remaining_bytes, num_bytes))
        num_remaining_bytes -= len(response_data)
        if read_time is None:
          # The timestamp line may be split across several recv() calls.
          header += response_data
          if '\n' in header:
            read_time = header.split('\n', 1)[0]
    if read_time is None:
      self.fail('Response did not contain a timestamp line')
    return GetElapsedMs(read_time, self.timer())

  def testTcpConnectToIp(self):
    """Verify that it takes |delay_ms| to establish a TCP connection."""
    if not platformsettings.has_ipfw():
      logging.warning('ipfw is not available in path. Skip the test')
      return
    with TimedTcpServer(self.host, self.port):
      for delay_ms in (100, 175):
        with self.TrafficShaper(delay_ms=delay_ms):
          start_time = self.timer()
          with self.tcp_socket_creator:
            connect_time = GetElapsedMs(start_time, self.timer())
        self.assertValuesAlmostEqual(delay_ms, connect_time, tolerance=0.12)

  def testTcpUploadShaping(self):
    """Verify that 'up' bandwidth is shaped on TCP connections."""
    if not platformsettings.has_ipfw():
      logging.warning('ipfw is not available in path. Skip the test')
      return
    num_bytes = 1024 * 100
    bandwidth_kbits = 2000
    expected_ms = 8.0 * num_bytes / bandwidth_kbits
    with TimedTcpServer(self.host, self.port):
      with self.TrafficShaper(up_bandwidth='%sKbit/s' % bandwidth_kbits):
        send_ms = self.GetTcpSendTimeMs(num_bytes)
        self.assertValuesAlmostEqual(expected_ms, send_ms)

  def testTcpDownloadShaping(self):
    """Verify that 'down' bandwidth is shaped on TCP connections."""
    if not platformsettings.has_ipfw():
      logging.warning('ipfw is not available in path. Skip the test')
      return
    num_bytes = 1024 * 100
    bandwidth_kbits = 2000
    expected_ms = 8.0 * num_bytes / bandwidth_kbits
    with TimedTcpServer(self.host, self.port):
      with self.TrafficShaper(down_bandwidth='%sKbit/s' % bandwidth_kbits):
        receive_ms = self.GetTcpReceiveTimeMs(num_bytes)
        self.assertValuesAlmostEqual(expected_ms, receive_ms)

  def testTcpInterleavedDownloads(self):
    # TODO(slamm): write tcp interleaved downloads test
    pass
227
228
class UdpTrafficShaperTest(TimedTestCase):
  """Checks TrafficShaper delay on UDP traffic to TEST_DNS_PORT.

  Tests log a warning and return early (and so pass) when ipfw is
  unavailable.
  """

  def setUp(self):
    self.host = platformsettings.get_server_ip_address()
    self.dns_port = TEST_DNS_PORT
    self.timer = TIMER

  def TrafficShaper(self, **kwargs):
    """Return a trafficshaper.TrafficShaper bound to this host/DNS port.

    Args:
      **kwargs: extra TrafficShaper arguments (e.g. delay_ms).
    """
    return trafficshaper.TrafficShaper(
        host=self.host, ports=(self.dns_port,), **kwargs)

  def GetUdpSendReceiveTimesMs(self, timeout=1.0):
    """Send one datagram and return (send_ms, receive_ms).

    send_ms runs from just before the socket is created until the timestamp
    in the server's reply; receive_ms runs from that timestamp until the
    reply has been received.

    Args:
      timeout: seconds to wait for the reply; if exceeded, socket.timeout is
          raised so a lost datagram fails the test instead of hanging it.
    Returns:
      (send_ms, receive_ms) as integers.
    """
    start_time = self.timer()
    udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
      udp_socket.settimeout(timeout)
      udp_socket.sendto('test data\n', (self.host, self.dns_port))
      read_time = udp_socket.recv(1024)
      end_time = self.timer()
    finally:
      # Release the socket even when recv() times out.
      udp_socket.close()
    return (GetElapsedMs(start_time, read_time),
            GetElapsedMs(read_time, end_time))

  def testUdpDelay(self):
    """Verify each direction of a UDP round trip takes half of |delay_ms|."""
    if not platformsettings.has_ipfw():
      logging.warning('ipfw is not available in path. Skip the test')
      return
    for delay_ms in (100, 170):
      expected_ms = delay_ms / 2
      with TimedUdpServer(self.host, self.dns_port):
        with self.TrafficShaper(delay_ms=delay_ms):
          send_ms, receive_ms = self.GetUdpSendReceiveTimesMs()
          self.assertValuesAlmostEqual(expected_ms, send_ms, tolerance=0.10)
          self.assertValuesAlmostEqual(expected_ms, receive_ms, tolerance=0.10)

  def testUdpInterleavedDelay(self):
    # TODO(slamm): write udp interleaved udp delay test
    pass
265
266
class TcpAndUdpTrafficShaperTest(TimedTestCase):
  """Placeholder for tests of concurrent TCP and UDP traffic (none yet)."""
  # TODO(slamm): Test concurrent TCP and UDP traffic
270
271
272# TODO(slamm): Packet loss rate (try different ports)
273
274
if __name__ == '__main__':
  # For debug output, uncomment the next line:
  # logging.getLogger().setLevel(logging.DEBUG)
  unittest.main()
278