1#!/usr/bin/python
2# Copyright 2017 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6import Queue
7import array
8import collections
9import os
10import shutil
11import tempfile
12import threading
13import unittest
14from contextlib import contextmanager
15from multiprocessing import connection
16
17import common
18from autotest_lib.site_utils import lxc
19from autotest_lib.site_utils.lxc import unittest_setup
20from autotest_lib.site_utils.lxc.container_pool import message
21from autotest_lib.site_utils.lxc.container_pool import service
22from autotest_lib.site_utils.lxc.container_pool import unittest_client
23
24
# Minimal stand-in for a host directory object: it exposes only a `path`
# attribute.  setUp wraps each per-test temp dir in one, and run_service hands
# it to service.Service as the host dir.
FakeHostDir = collections.namedtuple('FakeHostDir', 'path')
26
27
class ServiceTests(unittest.TestCase):
    """Unit tests for the Service class.

    Each test runs a service.Service on a background thread, with its socket
    address placed inside a per-test temporary directory, and talks to it
    through client connections.
    """

    @classmethod
    def setUpClass(cls):
        """Creates a directory for running the unit tests. """
        # Explicitly use /tmp as the tmpdir.  Board specific TMPDIRs inside of
        # the chroot are set to a path that causes the socket address to exceed
        # the maximum allowable length.
        cls.test_dir = tempfile.mkdtemp(prefix='service_unittest_', dir='/tmp')


    @classmethod
    def tearDownClass(cls):
        """Deletes the test directory (including all per-test subdirs). """
        shutil.rmtree(cls.test_dir)


    def setUp(self):
        """Per-test setup."""
        # Put each test in its own test dir, so it's hermetic.
        self.test_dir = tempfile.mkdtemp(dir=ServiceTests.test_dir)
        self.host_dir = FakeHostDir(self.test_dir)
        # Socket path the service is expected to listen on; testStatus checks
        # that the service reports this same path.
        self.address = os.path.join(self.test_dir,
                                    lxc.DEFAULT_CONTAINER_POOL_SOCKET)


    def testConnection(self):
        """Tests a simple connection to the pool service."""
        with self.run_service():
            self.assertTrue(self._pool_is_healthy())


    def testAbortedConnection(self):
        """Tests that a closed connection doesn't crash the service."""
        with self.run_service():
            client = connection.Client(self.address)
            client.close()
            self.assertTrue(self._pool_is_healthy())


    def testCorruptedMessage(self):
        """Tests that corrupted messages don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a raw array of bytes.  This will cause an unpickling error.
            client.send_bytes(array.array('i', range(1, 10)))
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())


    def testInvalidMessageClass(self):
        """Tests that bad messages don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a valid object but not of the right Message class.
            client.send('foo')
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())


    def testInvalidMessageType(self):
        """Tests that messages with a bad type don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send an object of the right Message class, but with an
            # unrecognized message type.
            client.send(message.Message('foo', None))
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())


    def testStop(self):
        """Tests stopping the service."""
        with self.run_service() as svc, self.create_client() as client:
            self.assertTrue(svc.is_running())
            client.send(message.shutdown())
            client.recv()  # wait for ack
            self.assertFalse(svc.is_running())


    def testStatus(self):
        """Tests querying service status."""
        pool = MockPool()
        with self.run_service(pool), self.create_client() as client:
            client.send(message.status())
            self._assert_status_matches(pool, client.recv())

            # Change some values, ensure the changes are reflected.
            pool.capacity = 42
            pool.size = 19
            pool.worker_count = 3
            error_count = 8
            for e in range(error_count):
                pool.errors.put(e)
            client.send(message.status())
            self._assert_status_matches(pool, client.recv())


    def testGet(self):
        """Tests getting a container from the pool."""
        test_pool = MockPool()
        fake_container = MockContainer()
        test_id = lxc.ContainerId.create(42)
        test_pool.containers.put(fake_container)

        with self.run_service(test_pool), self.create_client() as client:
            client.send(message.get(test_id))
            test_container = client.recv()
            # The returned container should carry the id that was requested.
            self.assertEqual(test_id, test_container.id)


    def testGet_timeoutImmediate(self):
        """Tests that getting from an empty pool with the default timeout
        returns None."""
        test_id = lxc.ContainerId.create(42)
        with self.run_service(), self.create_client() as client:
            client.send(message.get(test_id))
            test_container = client.recv()
            self.assertIsNone(test_container)


    def testGet_timeoutDelayed(self):
        """Tests that getting from an empty pool with a 1-second timeout
        returns None."""
        test_id = lxc.ContainerId.create(42)
        with self.run_service(), self.create_client() as client:
            client.send(message.get(test_id, timeout=1))
            test_container = client.recv()
            self.assertIsNone(test_container)


    def testMultipleClients(self):
        """Tests multiple simultaneous connections."""
        with self.run_service():
            with self.create_client() as client0:
                with self.create_client() as client1:

                    msg0 = 'two driven jocks help fax my big quiz'
                    msg1 = 'how quickly daft jumping zebras vex'

                    # Interleave sends before any recv so both connections
                    # are open and pending at the same time.
                    client0.send(message.echo(msg0))
                    client1.send(message.echo(msg1))

                    echo0 = client0.recv()
                    echo1 = client1.recv()

                    self.assertEqual(msg0, echo0)
                    self.assertEqual(msg1, echo1)


    def _assert_status_matches(self, pool, status):
        """Checks a status response against the current service/pool state.

        @param pool: The MockPool the service was started with.
        @param status: The status dict received from the service.
        """
        self.assertTrue(status['running'])
        self.assertEqual(self.address, status['socket_path'])
        self.assertEqual(pool.capacity, status['pool capacity'])
        self.assertEqual(pool.size, status['pool size'])
        self.assertEqual(pool.worker_count, status['pool worker count'])
        self.assertEqual(pool.errors.qsize(), status['pool errors'])


    def _pool_is_healthy(self):
        """Verifies that the pool service is still functioning.

        Sends an echo message and tests for a response.  This is a stronger
        signal of aliveness than checking Service.is_running, but a False return
        value does not necessarily indicate that the pool service shut down
        cleanly.  Use Service.is_running to check that.

        @return: True if the echoed message matches what was sent.
        """
        with self.create_client() as client:
            msg = 'foobar'
            client.send(message.echo(msg))
            return client.recv() == msg


    @contextmanager
    def run_service(self, pool=None):
        """Creates and cleans up a Service instance.

        The service's start() method runs on a background thread.  On exit the
        service is stopped and the thread is given up to one second to finish.

        @param pool: The pool to serve from.  Defaults to an empty MockPool.
        @yield: The service.Service instance.
        """
        if pool is None:
            pool = MockPool()
        svc = service.Service(self.host_dir, pool)
        thread = threading.Thread(name='service', target=svc.start)
        # Daemonize so that a service thread which fails to stop within the
        # join timeout below cannot keep the test process alive afterwards.
        thread.daemon = True
        thread.start()
        try:
            yield svc
        finally:
            svc.stop()
            thread.join(1)


    @contextmanager
    def create_client(self):
        """Creates and cleans up a client connection.

        @yield: A client connected to self.address.
        """
        client = unittest_client.connect(self.address)
        try:
            yield client
        finally:
            client.close()
234
235
class MockPool(object):
    """A mock pool class for testing the service.

    Exposes the attributes the status tests compare against (capacity, size,
    worker_count, errors) plus the cleanup() and get() methods required by the
    pool interface.
    """

    def __init__(self):
        """Initializes a mock empty pool."""
        self.capacity = 0
        self.size = 0
        self.worker_count = 0
        # Tests push arbitrary items here; the status tests compare its
        # qsize() against the reported error count.
        self.errors = Queue.Queue()
        # Containers handed out by get(); tests pre-load this queue.
        self.containers = Queue.Queue()


    def cleanup(self):
        """Required by pool interface.  Does nothing."""
        pass


    def get(self, timeout=0):
        """Required by pool interface.

        @param timeout: Seconds to wait for a container.  A value <= 0, or
                        None, makes the call non-blocking.

        @return: A container from the containers queue, or None if none became
                 available within the timeout.
        """
        # Treat None like 0.  `None > 0` only works in Python 2, where it is
        # False, and blocking forever on an empty mock would hang a test.
        blocking = timeout is not None and timeout > 0
        try:
            return self.containers.get(block=blocking, timeout=timeout)
        except Queue.Empty:
            return None
262
263
class MockContainer(object):
    """Stand-in for a container object handed out by MockPool.

    Starts without an id.  testGet expects the container that comes back from
    the service to carry the requested id, so it is presumably assigned
    downstream.
    """

    def __init__(self):
        """Creates a mock container with no id and a fixed name."""
        self.id, self.name = None, 'test_container'
271
272
if __name__ == '__main__':
    # Prepare the lxc unittest environment without requiring sudo, then hand
    # control to the standard unittest runner.
    unittest_setup.setup(require_sudo=False)
    unittest.main()
276