1# Copyright 2016 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import cStringIO
6import inspect
7import json
8import unittest
9
10import common
11from autotest_lib.server.hosts import host_info
12
13
class HostInfoTest(unittest.TestCase):
    """Tests the non-trivial attributes of HostInfo."""

    def setUp(self):
        self.info = host_info.HostInfo()


    def test_info_comparison_to_wrong_type(self):
        """Comparing HostInfo to a different type always returns False."""
        self.assertNotEqual(host_info.HostInfo(), 42)
        self.assertNotEqual(host_info.HostInfo(), None)
        # The data model implies no relationship between == and !=, so both
        # operators are exercised explicitly. `== None` is deliberate here: it
        # exercises HostInfo.__eq__ rather than an identity check.
        self.assertFalse(host_info.HostInfo() == 42)
        self.assertFalse(host_info.HostInfo() == None)


    def test_empty_infos_are_equal(self):
        """Tests that empty HostInfo objects are considered equal."""
        self.assertEqual(host_info.HostInfo(), host_info.HostInfo())
        # The data model implies no relationship between == and !=, so both
        # operators are exercised explicitly.
        self.assertFalse(host_info.HostInfo() != host_info.HostInfo())


    def test_non_trivial_infos_are_equal(self):
        """Tests that the most complicated infos are correctly stated equal."""
        info1 = host_info.HostInfo(
                labels=['label1', 'label2', 'label1'],
                attributes={'attrib1': None, 'attrib2': 'val2'})
        info2 = host_info.HostInfo(
                labels=['label1', 'label2', 'label1'],
                attributes={'attrib1': None, 'attrib2': 'val2'})
        self.assertEqual(info1, info2)
        # The data model implies no relationship between == and !=, so both
        # operators are exercised explicitly.
        self.assertFalse(info1 != info2)


    def test_non_equal_infos(self):
        """Tests that HostInfo objects with different information are unequal"""
        info1 = host_info.HostInfo(labels=['label'])
        info2 = host_info.HostInfo(attributes={'attrib': 'value'})
        self.assertNotEqual(info1, info2)
        # The data model implies no relationship between == and !=, so both
        # operators are exercised explicitly.
        self.assertFalse(info1 == info2)


    def test_build_needs_prefix(self):
        """The build prefix is of the form '<type>-version:'"""
        self.info.labels = ['cros-version', 'ab-version', 'testbed-version',
                            'fwrw-version', 'fwro-version']
        self.assertIsNone(self.info.build)


    def test_build_prefix_must_be_anchored(self):
        """Ensure that build ignores prefixes occurring mid-string."""
        self.info.labels = ['not-at-start-cros-version:cros1',
                            'not-at-start-ab-version:ab1',
                            'not-at-start-testbed-version:testbed1']
        self.assertIsNone(self.info.build)


    def test_build_ignores_firmware(self):
        """build attribute should ignore firmware versions."""
        self.info.labels = ['fwrw-version:fwrw1', 'fwro-version:fwro1']
        self.assertIsNone(self.info.build)


    def test_build_returns_first_match(self):
        """When multiple labels match, first one should be used as build."""
        self.info.labels = ['cros-version:cros1', 'cros-version:cros2']
        self.assertEqual(self.info.build, 'cros1')
        self.info.labels = ['ab-version:ab1', 'ab-version:ab2']
        self.assertEqual(self.info.build, 'ab1')
        self.info.labels = ['testbed-version:tb1', 'testbed-version:tb2']
        self.assertEqual(self.info.build, 'tb1')


    def test_build_prefer_cros_over_others(self):
        """When multiple versions are available, prefer cros."""
        self.info.labels = ['testbed-version:tb1', 'ab-version:ab1',
                            'cros-version:cros1']
        self.assertEqual(self.info.build, 'cros1')
        self.info.labels = ['cros-version:cros1', 'ab-version:ab1',
                            'testbed-version:tb1']
        self.assertEqual(self.info.build, 'cros1')


    def test_build_prefer_ab_over_testbed(self):
        """When multiple versions are available, prefer ab over testbed."""
        self.info.labels = ['testbed-version:tb1', 'ab-version:ab1']
        self.assertEqual(self.info.build, 'ab1')
        self.info.labels = ['ab-version:ab1', 'testbed-version:tb1']
        self.assertEqual(self.info.build, 'ab1')


    def test_os_no_match(self):
        """Use proper prefix to search for os information."""
        self.info.labels = ['something_else', 'cros-version:hana',
                            'os_without_colon']
        self.assertEqual(self.info.os, '')


    def test_os_returns_first_match(self):
        """Return the first matching os label."""
        self.info.labels = ['os:linux', 'os:windows', 'os_corrupted_label']
        self.assertEqual(self.info.os, 'linux')


    def test_board_no_match(self):
        """Use proper prefix to search for board information."""
        self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
                            'board_my_board_no_colon']
        self.assertEqual(self.info.board, '')


    def test_board_returns_first_match(self):
        """Return the first matching board label."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored']
        self.assertEqual(self.info.board, 'walk')


    def test_pools_no_match(self):
        """Use proper prefix to search for pool information."""
        self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
                            'board_my_board_no_colon', 'board:my_board']
        self.assertEqual(self.info.pools, set())


    def test_pools_returns_all_matches(self):
        """Return all matching pool labels."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored',
                            'pool:first_pool', 'pool:second_pool']
        self.assertEqual(self.info.pools, {'second_pool', 'first_pool'})


    def test_str(self):
        """Sanity checks the __str__ implementation."""
        info = host_info.HostInfo(labels=['a'], attributes={'b': 2})
        self.assertEqual(str(info),
                         "HostInfo[Labels: ['a'], Attributes: {'b': 2}]")


    def test_clear_version_labels_no_labels(self):
        """When no version labels exist, do nothing for clear_version_labels."""
        original_labels = ['board:something', 'os:something_else',
                           'pool:mypool', 'ab-version-corrupted:blah',
                           'cros-version']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, original_labels)


    def test_clear_all_version_labels(self):
        """Clear each recognized type of version label."""
        original_labels = ['extra_label', 'cros-version:cr1', 'ab-version:ab1',
                           'testbed-version:tb1']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_clear_all_version_label_prefixes(self):
        """Clear each recognized type of version label with empty value."""
        original_labels = ['extra_label', 'cros-version:', 'ab-version:',
                           'testbed-version:']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_set_version_labels_updates_in_place(self):
        """Update version label in place if prefix already exists."""
        self.info.labels = ['extra', 'cros-version:X', 'ab-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        self.assertListEqual(self.info.labels, ['extra', 'cros-version:Z',
                                                'ab-version:Y'])


    def test_set_version_labels_appends(self):
        """Append a new version label if the prefix doesn't exist."""
        self.info.labels = ['extra', 'ab-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        self.assertListEqual(self.info.labels, ['extra', 'ab-version:Y',
                                                'cros-version:Z'])
194
195
class InMemoryHostInfoStoreTest(unittest.TestCase):
    """Basic tests for CachingHostInfoStore using InMemoryHostInfoStore."""

    def setUp(self):
        self.store = host_info.InMemoryHostInfoStore()


    def _verify_host_info_data(self, info, labels, attributes):
        """Verifies that |info| holds exactly the given labels and attributes.

        The parameter is deliberately not called `host_info`, which would
        shadow the module of the same name used throughout this class.

        @param info: The HostInfo object to inspect.
        @param labels: Expected list of labels (order matters).
        @param attributes: Expected attributes dict.
        """
        self.assertListEqual(info.labels, labels)
        self.assertDictEqual(info.attributes, attributes)


    def test_first_get_refreshes_cache(self):
        """Test that the first call to get gets the data from store."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_repeated_get_returns_from_cache(self):
        """Tests that repeated calls to get do not refresh cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_get_uncached_always_refreshes_cache(self):
        """Tests that get(force_refresh=True) always refreshes the cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1', 'label2'], {})


    def test_commit(self):
        """Test that commit sends data to store."""
        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self._verify_host_info_data(self.store.info, [], {})
        self.store.commit(info)
        self._verify_host_info_data(self.store.info, ['label1'],
                                    {'attrib1': 'val1'})


    def test_commit_then_get(self):
        """Test a commit-get roundtrip."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_commit_then_get_uncached(self):
        """Test a commit / get(force_refresh=True) roundtrip."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_commit_deepcopies_data(self):
        """Once committed, changes to HostInfo don't corrupt the store."""
        info = host_info.HostInfo(['label1'], {'attrib1': {'key1': 'data1'}})
        self.store.commit(info)
        info.labels.append('label2')
        info.attributes['attrib1']['key1'] = 'data2'
        self._verify_host_info_data(self.store.info,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_get_returns_deepcopy(self):
        """The cached object is protected from |get| caller modifications."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})
        got.labels.append('label2')
        got.attributes['attrib1']['key1'] = 'data2'
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_str(self):
        """Sanity tests __str__ implementation."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        self.assertEqual(str(self.store),
                         'InMemoryHostInfoStore[%s]' % self.store.info)
299
300
class ExceptionRaisingStore(host_info.CachingHostInfoStore):
    """A CachingHostInfoStore stub whose backend calls fail on demand.

    Both `refresh_raises` and `commit_raises` start out True, so refresh and
    commit raise StoreError until a test explicitly flips a flag off.
    """

    def __init__(self):
        super(ExceptionRaisingStore, self).__init__()
        self.refresh_raises = self.commit_raises = True


    def _refresh_impl(self):
        """Returns an empty HostInfo, or raises StoreError if refresh_raises."""
        if not self.refresh_raises:
            return host_info.HostInfo()
        raise host_info.StoreError('no can do')

    def _commit_impl(self, _):
        """Discards the given info, or raises StoreError if commit_raises."""
        if not self.commit_raises:
            return
        raise host_info.StoreError('wont wont wont')
318
319
class CachingHostInfoStoreErrorTest(unittest.TestCase):
    """Tests error behaviours of CachingHostInfoStore."""

    def setUp(self):
        self.store = ExceptionRaisingStore()


    def test_failed_refresh_cleans_cache(self):
        """A failed refresh caches nothing, so the next get retries the store."""
        with self.assertRaises(host_info.StoreError):
            self.store.get()
        # Since |get| hit an error, a subsequent get should again hit the store.
        with self.assertRaises(host_info.StoreError):
            self.store.get()


    def test_failed_commit_cleans_cache(self):
        """Check that a failed commit cleans the cache."""
        # Let's initialize the store without errors.
        self.store.refresh_raises = False
        self.store.get(force_refresh=True)
        self.store.refresh_raises = True

        with self.assertRaises(host_info.StoreError):
            self.store.commit(host_info.HostInfo())
        # Since |commit| hit an error, a subsequent get should again hit the
        # store.
        with self.assertRaises(host_info.StoreError):
            self.store.get()
349
350
class GetStoreFromMachineTest(unittest.TestCase):
    """Tests the get_store_from_machine function."""

    def test_machine_is_dict(self):
        """We extract the store when machine is a dict."""
        machine = {
                'something': 'else',
                'host_info_store': 5
        }
        self.assertEqual(host_info.get_store_from_machine(machine), 5)


    def test_machine_is_string(self):
        """We return a trivial store when machine is a string."""
        machine = 'hostname'
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(host_info.get_store_from_machine(machine),
                              host_info.InMemoryHostInfoStore)
368
369
class HostInfoJsonSerializationTestCase(unittest.TestCase):
    """Tests the json_serialize and json_deserialize functions."""

    CURRENT_SERIALIZATION_VERSION = host_info._CURRENT_SERIALIZATION_VERSION

    def test_serialize_empty(self):
        """Serializing empty HostInfo results in the expected json."""
        info = host_info.HostInfo()
        file_obj = cStringIO.StringIO()
        host_info.json_serialize(info, file_obj)
        file_obj.seek(0)
        expected_dict = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {},
                'labels': [],
        }
        self.assertEqual(json.load(file_obj), expected_dict)


    def test_serialize_non_empty(self):
        """Serializing a populated HostInfo results in expected json."""
        info = host_info.HostInfo(labels=['label1'],
                                  attributes={'attrib': 'val'})
        file_obj = cStringIO.StringIO()
        host_info.json_serialize(info, file_obj)
        file_obj.seek(0)
        expected_dict = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {'attrib': 'val'},
                'labels': ['label1'],
        }
        self.assertEqual(json.load(file_obj), expected_dict)


    def test_round_trip_empty(self):
        """Serializing - deserializing empty HostInfo keeps it unchanged."""
        info = host_info.HostInfo()
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        serialized_fp.seek(0)
        got = host_info.json_deserialize(serialized_fp)
        self.assertEqual(got, info)


    def test_round_trip_non_empty(self):
        """Serializing - deserializing non-empty HostInfo keeps it unchanged."""
        info = host_info.HostInfo(
                labels=['label1'],
                attributes={'attrib': 'val'})
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        serialized_fp.seek(0)
        got = host_info.json_deserialize(serialized_fp)
        self.assertEqual(got, info)


    def test_deserialize_malformed_json_raises(self):
        """Deserializing a malformed string raises."""
        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(cStringIO.StringIO('{labels:['))


    def test_deserialize_no_version_raises(self):
        """Deserializing a string with no serializer version raises."""
        info = host_info.HostInfo()
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        serialized_fp.seek(0)

        serialized_dict = json.load(serialized_fp)
        del serialized_dict['serializer_version']
        serialized_no_version_str = json.dumps(serialized_dict)

        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(
                    cStringIO.StringIO(serialized_no_version_str))


    def test_deserialize_malformed_host_info_raises(self):
        """Deserializing a malformed host_info (missing labels) raises."""
        info = host_info.HostInfo()
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        serialized_fp.seek(0)

        serialized_dict = json.load(serialized_fp)
        del serialized_dict['labels']
        serialized_no_labels_str = json.dumps(serialized_dict)

        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(
                    cStringIO.StringIO(serialized_no_labels_str))


    def test_enforce_compatibility_version_1(self):
        """Tests that required fields are never dropped.

        Never change this test. If you must break compatibility, uprev the
        serializer version and add a new test for the newer version.

        Adding a field to compat_dict means we're making the new field
        mandatory. This breaks backwards compatibility.
        Removing a field from compat_dict means we're no longer requiring a
        field to be mandatory. This breaks forwards compatibility.
        """
        compat_dict = {
                'serializer_version': 1,
                'attributes': {},
                'labels': []
        }
        serialized_str = json.dumps(compat_dict)
        serialized_fp = cStringIO.StringIO(serialized_str)
        host_info.json_deserialize(serialized_fp)


    def test_serialize_pretty_print(self):
        """Serializing a host_info dumps the json in human-friendly format"""
        info = host_info.HostInfo(labels=['label1'],
                                  attributes={'attrib': 'val'})
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        # inspect.cleandoc strips the common leading indentation, leaving a
        # 4-space indented, key-sorted JSON document.
        expected = """{
            "attributes": {
                "attrib": "val"
            },
            "labels": [
                "label1"
            ],
            "serializer_version": %d
        }""" % self.CURRENT_SERIALIZATION_VERSION
        self.assertEqual(serialized_fp.getvalue(), inspect.cleandoc(expected))
501
502
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module when the file is
    # executed directly.
    unittest.main()
505