#!/usr/bin/python
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import Queue
import array
import collections
import os
import shutil
import tempfile
import threading
import unittest
from contextlib import contextmanager
from multiprocessing import connection

import common
from autotest_lib.site_utils import lxc
from autotest_lib.site_utils.lxc import unittest_setup
from autotest_lib.site_utils.lxc.container_pool import message
from autotest_lib.site_utils.lxc.container_pool import service
from autotest_lib.site_utils.lxc.container_pool import unittest_client


# Minimal stand-in for the host directory object the Service is constructed
# with (see ServiceTests.run_service); only the `path` field is provided.
FakeHostDir = collections.namedtuple('FakeHostDir', ['path'])


class ServiceTests(unittest.TestCase):
    """Unit tests for the Service class."""

    @classmethod
    def setUpClass(cls):
        """Creates a directory for running the unit tests. """
        # Explicitly use /tmp as the tmpdir. Board specific TMPDIRs inside of
        # the chroot are set to a path that causes the socket address to exceed
        # the maximum allowable length.
        cls.test_dir = tempfile.mkdtemp(prefix='service_unittest_', dir='/tmp')


    @classmethod
    def tearDownClass(cls):
        """Deletes the test directory. """
        shutil.rmtree(cls.test_dir)


    def setUp(self):
        """Per-test setup."""
        # Put each test in its own test dir, so it's hermetic.
        self.test_dir = tempfile.mkdtemp(dir=ServiceTests.test_dir)
        self.host_dir = FakeHostDir(self.test_dir)
        self.address = os.path.join(self.test_dir,
                                    lxc.DEFAULT_CONTAINER_POOL_SOCKET)


    def testConnection(self):
        """Tests a simple connection to the pool service."""
        with self.run_service():
            self.assertTrue(self._pool_is_healthy())


    def testAbortedConnection(self):
        """Tests that a closed connection doesn't crash the service."""
        with self.run_service():
            client = connection.Client(self.address)
            client.close()
            self.assertTrue(self._pool_is_healthy())


    def testCorruptedMessage(self):
        """Tests that corrupted messages don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a raw array of bytes. This will cause an unpickling error.
            client.send_bytes(array.array('i', range(1, 10)))
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())


    def testInvalidMessageClass(self):
        """Tests that bad messages don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a valid object but not of the right Message class.
            client.send('foo')
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())


    def testInvalidMessageType(self):
        """Tests that messages with a bad type don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a genuine Message object whose type is not a recognized
            # message type.
            client.send(message.Message('foo', None))
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())


    def testStop(self):
        """Tests stopping the service."""
        with self.run_service() as svc, self.create_client() as client:
            self.assertTrue(svc.is_running())
            client.send(message.shutdown())
            client.recv()  # wait for ack
            self.assertFalse(svc.is_running())


    def testStatus(self):
        """Tests querying service status."""
        pool = MockPool()
        with self.run_service(pool), self.create_client() as client:
            client.send(message.status())
            status = client.recv()
            self.assertTrue(status['running'])
            self.assertEqual(self.address, status['socket_path'])
            self.assertEqual(pool.capacity, status['pool capacity'])
            self.assertEqual(pool.size, status['pool size'])
            self.assertEqual(pool.worker_count, status['pool worker count'])
            self.assertEqual(pool.errors.qsize(), status['pool errors'])

            # Change some values, ensure the changes are reflected.
            pool.capacity = 42
            pool.size = 19
            pool.worker_count = 3
            error_count = 8
            for e in range(error_count):
                pool.errors.put(e)
            client.send(message.status())
            status = client.recv()
            self.assertTrue(status['running'])
            self.assertEqual(self.address, status['socket_path'])
            self.assertEqual(pool.capacity, status['pool capacity'])
            self.assertEqual(pool.size, status['pool size'])
            self.assertEqual(pool.worker_count, status['pool worker count'])
            self.assertEqual(pool.errors.qsize(), status['pool errors'])


    def testGet(self):
        """Tests getting a container from the pool."""
        test_pool = MockPool()
        fake_container = MockContainer()
        test_id = lxc.ContainerId.create(42)
        test_pool.containers.put(fake_container)

        with self.run_service(test_pool):
            with self.create_client() as client:
                client.send(message.get(test_id))
                test_container = client.recv()
                # The container handed back is expected to carry the id that
                # was requested, even though MockContainer starts with None.
                self.assertEqual(test_id, test_container.id)


    def testGet_timeoutImmediate(self):
        """Tests that a get with the default timeout on an empty pool
        returns None."""
        test_id = lxc.ContainerId.create(42)
        with self.run_service():
            with self.create_client() as client:
                client.send(message.get(test_id))
                test_container = client.recv()
                self.assertIsNone(test_container)


    def testGet_timeoutDelayed(self):
        """Tests that a get with a non-zero timeout on an empty pool
        returns None once the timeout expires."""
        test_id = lxc.ContainerId.create(42)
        with self.run_service():
            with self.create_client() as client:
                client.send(message.get(test_id, timeout=1))
                test_container = client.recv()
                self.assertIsNone(test_container)


    def testMultipleClients(self):
        """Tests multiple simultaneous connections."""
        with self.run_service():
            with self.create_client() as client0:
                with self.create_client() as client1:

                    msg0 = 'two driven jocks help fax my big quiz'
                    msg1 = 'how quickly daft jumping zebras vex'

                    client0.send(message.echo(msg0))
                    client1.send(message.echo(msg1))

                    echo0 = client0.recv()
                    echo1 = client1.recv()

                    self.assertEqual(msg0, echo0)
                    self.assertEqual(msg1, echo1)


    def _pool_is_healthy(self):
        """Verifies that the pool service is still functioning.

        Sends an echo message and tests for a response. This is a stronger
        signal of aliveness than checking Service.is_running, but a False return
        value does not necessarily indicate that the pool service shut down
        cleanly. Use Service.is_running to check that.

        @return: True if the echoed message matches what was sent.
        """
        with self.create_client() as client:
            msg = 'foobar'
            client.send(message.echo(msg))
            return client.recv() == msg


    @contextmanager
    def run_service(self, pool=None):
        """Creates and cleans up a Service instance.

        Starts Service.start on a background thread, yields the service, and
        on exit stops it and waits up to 1 second for the thread to finish.

        @param pool: The pool backing the service. Defaults to a fresh,
                empty MockPool.

        @yield: The service.Service instance under test.
        """
        if pool is None:
            pool = MockPool()
        svc = service.Service(self.host_dir, pool)
        thread = threading.Thread(name='service', target=svc.start)
        # The join below is bounded, so a service thread that fails to stop
        # can outlive the test. Make it a daemon so such a straggler cannot
        # keep the interpreter (and thus the whole test run) from exiting.
        thread.daemon = True
        thread.start()
        try:
            yield svc
        finally:
            svc.stop()
            thread.join(1)


    @contextmanager
    def create_client(self):
        """Creates and cleans up a client connection.

        @yield: A client connected to this test's service socket; it is
                closed on exit.
        """
        client = unittest_client.connect(self.address)
        try:
            yield client
        finally:
            client.close()


class MockPool(object):
    """A mock pool class for testing the service."""

    def __init__(self):
        """Initializes a mock empty pool."""
        self.capacity = 0
        self.size = 0
        self.worker_count = 0
        # Reported to status queries via errors.qsize().
        self.errors = Queue.Queue()
        # Containers handed out by get(); tests pre-load this queue.
        self.containers = Queue.Queue()


    def cleanup(self):
        """Required by pool interface. Does nothing."""
        pass


    def get(self, timeout=0):
        """Required by pool interface.

        @param timeout: Seconds to wait for a container. A value <= 0 means
                do not block at all.

        @return: A container from the containers queue, or None if none is
                available within the timeout.
        """
        try:
            return self.containers.get(block=(timeout > 0), timeout=timeout)
        except Queue.Empty:
            return None


class MockContainer(object):
    """A mock container class for testing the service."""

    def __init__(self):
        """Initializes a mock container with no id assigned."""
        # Left unset here; testGet expects the container returned to the
        # client to carry the requested id (presumably assigned by the
        # service -- confirm against service.Service).
        self.id = None
        self.name = 'test_container'


if __name__ == '__main__':
    unittest_setup.setup(require_sudo=False)
    unittest.main()