# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import cStringIO
import inspect
import json
import unittest

import common
from autotest_lib.server.hosts import host_info


class HostInfoTest(unittest.TestCase):
    """Tests the non-trivial attributes of HostInfo."""

    def setUp(self):
        self.info = host_info.HostInfo()


    def test_info_comparison_to_wrong_type(self):
        """Comparing HostInfo to a different type always returns False."""
        self.assertNotEqual(host_info.HostInfo(), 42)
        self.assertNotEqual(host_info.HostInfo(), None)
        # equality and non-equality are unrelated by the data model, so
        # exercise both operators explicitly. The '== None' comparison is
        # deliberate: it is the equality operator that is under test here.
        self.assertFalse(host_info.HostInfo() == 42)
        self.assertFalse(host_info.HostInfo() == None)


    def test_empty_infos_are_equal(self):
        """Tests that empty HostInfo objects are considered equal."""
        self.assertEqual(host_info.HostInfo(), host_info.HostInfo())
        # equality and non-equality are unrelated by the data model.
        self.assertFalse(host_info.HostInfo() != host_info.HostInfo())


    def test_non_trivial_infos_are_equal(self):
        """Tests that the most complicated infos are correctly stated equal."""
        info1 = host_info.HostInfo(
                labels=['label1', 'label2', 'label1'],
                attributes={'attrib1': None, 'attrib2': 'val2'})
        info2 = host_info.HostInfo(
                labels=['label1', 'label2', 'label1'],
                attributes={'attrib1': None, 'attrib2': 'val2'})
        self.assertEqual(info1, info2)
        # equality and non-equality are unrelated by the data model.
        self.assertFalse(info1 != info2)


    def test_non_equal_infos(self):
        """Tests that HostInfo objects with different information are unequal"""
        info1 = host_info.HostInfo(labels=['label'])
        info2 = host_info.HostInfo(attributes={'attrib': 'value'})
        self.assertNotEqual(info1, info2)
        # equality and non-equality are unrelated by the data model.
        self.assertFalse(info1 == info2)


    def test_build_needs_prefix(self):
        """The build prefix is of the form '<type>-version:'"""
        self.info.labels = ['cros-version', 'ab-version', 'testbed-version',
                            'fwrw-version', 'fwro-version']
        self.assertIsNone(self.info.build)


    def test_build_prefix_must_be_anchored(self):
        """Ensure that build ignores prefixes occurring mid-string."""
        self.info.labels = ['not-at-start-cros-version:cros1',
                            'not-at-start-ab-version:ab1',
                            'not-at-start-testbed-version:testbed1']
        self.assertIsNone(self.info.build)


    def test_build_ignores_firmware(self):
        """build attribute should ignore firmware versions."""
        self.info.labels = ['fwrw-version:fwrw1', 'fwro-version:fwro1']
        self.assertIsNone(self.info.build)


    def test_build_returns_first_match(self):
        """When multiple labels match, first one should be used as build."""
        self.info.labels = ['cros-version:cros1', 'cros-version:cros2']
        self.assertEqual(self.info.build, 'cros1')
        self.info.labels = ['ab-version:ab1', 'ab-version:ab2']
        self.assertEqual(self.info.build, 'ab1')
        self.info.labels = ['testbed-version:tb1', 'testbed-version:tb2']
        self.assertEqual(self.info.build, 'tb1')


    def test_build_prefer_cros_over_others(self):
        """When multiple versions are available, prefer cros."""
        self.info.labels = ['testbed-version:tb1', 'ab-version:ab1',
                            'cros-version:cros1']
        self.assertEqual(self.info.build, 'cros1')
        self.info.labels = ['cros-version:cros1', 'ab-version:ab1',
                            'testbed-version:tb1']
        self.assertEqual(self.info.build, 'cros1')


    def test_build_prefer_ab_over_testbed(self):
        """When multiple versions are available, prefer ab over testbed."""
        self.info.labels = ['testbed-version:tb1', 'ab-version:ab1']
        self.assertEqual(self.info.build, 'ab1')
        self.info.labels = ['ab-version:ab1', 'testbed-version:tb1']
        self.assertEqual(self.info.build, 'ab1')


    def test_os_no_match(self):
        """Use proper prefix to search for os information."""
        self.info.labels = ['something_else', 'cros-version:hana',
                            'os_without_colon']
        self.assertEqual(self.info.os, '')


    def test_os_returns_first_match(self):
        """Return the first matching os label."""
        self.info.labels = ['os:linux', 'os:windows', 'os_corrupted_label']
        self.assertEqual(self.info.os, 'linux')


    def test_board_no_match(self):
        """Use proper prefix to search for board information."""
        self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
                            'board_my_board_no_colon']
        self.assertEqual(self.info.board, '')


    def test_board_returns_first_match(self):
        """Return the first matching board label."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored']
        self.assertEqual(self.info.board, 'walk')


    def test_pools_no_match(self):
        """Use proper prefix to search for pool information."""
        self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
                            'board_my_board_no_colon', 'board:my_board']
        self.assertEqual(self.info.pools, set())


    def test_pools_returns_all_matches(self):
        """Return all matching pool labels."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored',
                            'pool:first_pool', 'pool:second_pool']
        self.assertEqual(self.info.pools, {'second_pool', 'first_pool'})


    def test_str(self):
        """Sanity checks the __str__ implementation."""
        info = host_info.HostInfo(labels=['a'], attributes={'b': 2})
        self.assertEqual(str(info),
                         "HostInfo[Labels: ['a'], Attributes: {'b': 2}]")


    def test_clear_version_labels_no_labels(self):
        """When no version labels exist, do nothing for clear_version_labels."""
        original_labels = ['board:something', 'os:something_else',
                           'pool:mypool', 'ab-version-corrupted:blah',
                           'cros-version']
        # Hand over a copy so that the expected value can't be mutated by
        # the call under test.
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, original_labels)


    def test_clear_all_version_labels(self):
        """Clear each recognized type of version label."""
        original_labels = ['extra_label', 'cros-version:cr1', 'ab-version:ab1',
                           'testbed-version:tb1']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_clear_all_version_label_prefixes(self):
        """Clear each recognized type of version label with empty value."""
        original_labels = ['extra_label', 'cros-version:', 'ab-version:',
                           'testbed-version:']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_set_version_labels_updates_in_place(self):
        """Update version label in place if prefix already exists."""
        self.info.labels = ['extra', 'cros-version:X', 'ab-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        self.assertListEqual(self.info.labels, ['extra', 'cros-version:Z',
                                                'ab-version:Y'])


    def test_set_version_labels_appends(self):
        """Append a new version label if the prefix doesn't exist."""
        self.info.labels = ['extra', 'ab-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        self.assertListEqual(self.info.labels, ['extra', 'ab-version:Y',
                                                'cros-version:Z'])


class InMemoryHostInfoStoreTest(unittest.TestCase):
    """Basic tests for CachingHostInfoStore using InMemoryHostInfoStore."""

    def setUp(self):
        self.store = host_info.InMemoryHostInfoStore()


    def _verify_host_info_data(self, info, labels, attributes):
        """Verifies the data in the given HostInfo.

        @param info: The HostInfo object to check.
        @param labels: The list of labels |info| is expected to carry.
        @param attributes: The dict of attributes |info| is expected to carry.
        """
        self.assertListEqual(info.labels, labels)
        self.assertDictEqual(info.attributes, attributes)


    def test_first_get_refreshes_cache(self):
        """Test that the first call to get gets the data from store."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_repeated_get_returns_from_cache(self):
        """Tests that repeated calls to get do not refresh cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        # The backing data changes, but a plain get must still serve the
        # previously cached value.
        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_get_uncached_always_refreshes_cache(self):
        """Tests that get(force_refresh=True) always refreshes the cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1', 'label2'], {})


    def test_commit(self):
        """Test that commit sends data to store."""
        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self._verify_host_info_data(self.store.info, [], {})
        self.store.commit(info)
        self._verify_host_info_data(self.store.info, ['label1'],
                                    {'attrib1': 'val1'})


    def test_commit_then_get(self):
        """Test a commit-get roundtrip."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_commit_then_get_uncached(self):
        """Test a commit followed by a get(force_refresh=True) roundtrip."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_commit_deepcopies_data(self):
        """Once committed, changes to HostInfo don't corrupt the store."""
        info = host_info.HostInfo(['label1'], {'attrib1': {'key1': 'data1'}})
        self.store.commit(info)
        # Mutate both the top level list and a nested dict: a shallow copy
        # in the store would be caught by the nested change.
        info.labels.append('label2')
        info.attributes['attrib1']['key1'] = 'data2'
        self._verify_host_info_data(self.store.info,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_get_returns_deepcopy(self):
        """The cached object is protected from |get| caller modifications."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})
        got.labels.append('label2')
        got.attributes['attrib1']['key1'] = 'data2'
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_str(self):
        """Sanity tests __str__ implementation."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        self.assertEqual(str(self.store),
                         'InMemoryHostInfoStore[%s]' % self.store.info)


class ExceptionRaisingStore(host_info.CachingHostInfoStore):
    """A test class that raises on refresh / commit unless told otherwise.

    Both |refresh_raises| and |commit_raises| start out True; a test may flip
    either flag to False to let the corresponding operation succeed.
    """

    def __init__(self):
        super(ExceptionRaisingStore, self).__init__()
        self.refresh_raises = True
        self.commit_raises = True


    def _refresh_impl(self):
        if self.refresh_raises:
            raise host_info.StoreError('no can do')
        return host_info.HostInfo()


    def _commit_impl(self, _):
        if self.commit_raises:
            raise host_info.StoreError('wont wont wont')


class CachingHostInfoStoreErrorTest(unittest.TestCase):
    """Tests error behaviours of CachingHostInfoStore."""

    def setUp(self):
        self.store = ExceptionRaisingStore()


    def test_failed_refresh_cleans_cache(self):
        """Sanity checks return values when refresh raises."""
        with self.assertRaises(host_info.StoreError):
            self.store.get()
        # Since |get| hit an error, a subsequent get should again hit the store.
        with self.assertRaises(host_info.StoreError):
            self.store.get()


    def test_failed_commit_cleans_cache(self):
        """Check that a failed commit cleans cache."""
        # Let's initialize the store without errors.
        self.store.refresh_raises = False
        self.store.get(force_refresh=True)
        self.store.refresh_raises = True

        with self.assertRaises(host_info.StoreError):
            self.store.commit(host_info.HostInfo())
        # Since |commit| hit an error, a subsequent get should again hit the
        # store.
        with self.assertRaises(host_info.StoreError):
            self.store.get()


class GetStoreFromMachineTest(unittest.TestCase):
    """Tests the get_store_from_machine function."""

    def test_machine_is_dict(self):
        """We extract the store when machine is a dict."""
        machine = {
                'something': 'else',
                'host_info_store': 5
        }
        self.assertEqual(host_info.get_store_from_machine(machine), 5)


    def test_machine_is_string(self):
        """We return a trivial store when machine is a string."""
        machine = 'hostname'
        self.assertIsInstance(host_info.get_store_from_machine(machine),
                              host_info.InMemoryHostInfoStore)


class HostInfoJsonSerializationTestCase(unittest.TestCase):
    """Tests the json_serialize and json_deserialize functions."""

    CURRENT_SERIALIZATION_VERSION = host_info._CURRENT_SERIALIZATION_VERSION

    def _serialize_to_fp(self, info):
        """Serializes |info| into a fresh in-memory file object.

        @param info: The HostInfo object to serialize.
        @returns: A cStringIO.StringIO holding the serialized json, rewound
                to the start so that it can be read back right away.
        """
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        serialized_fp.seek(0)
        return serialized_fp


    def test_serialize_empty(self):
        """Serializing empty HostInfo results in the expected json."""
        file_obj = self._serialize_to_fp(host_info.HostInfo())
        expected_dict = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {},
                'labels': [],
        }
        self.assertEqual(json.load(file_obj), expected_dict)


    def test_serialize_non_empty(self):
        """Serializing a populated HostInfo results in expected json."""
        info = host_info.HostInfo(labels=['label1'],
                                  attributes={'attrib': 'val'})
        file_obj = self._serialize_to_fp(info)
        expected_dict = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {'attrib': 'val'},
                'labels': ['label1'],
        }
        self.assertEqual(json.load(file_obj), expected_dict)


    def test_round_trip_empty(self):
        """Serializing - deserializing empty HostInfo keeps it unchanged."""
        info = host_info.HostInfo()
        got = host_info.json_deserialize(self._serialize_to_fp(info))
        self.assertEqual(got, info)


    def test_round_trip_non_empty(self):
        """Serializing - deserializing non-empty HostInfo keeps it unchanged."""
        info = host_info.HostInfo(
                labels=['label1'],
                attributes={'attrib': 'val'})
        got = host_info.json_deserialize(self._serialize_to_fp(info))
        self.assertEqual(got, info)


    def test_deserialize_malformed_json_raises(self):
        """Deserializing a malformed string raises."""
        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(cStringIO.StringIO('{labels:['))


    def test_deserialize_no_version_raises(self):
        """Deserializing a string with no serializer version raises."""
        serialized_fp = self._serialize_to_fp(host_info.HostInfo())

        serialized_dict = json.load(serialized_fp)
        del serialized_dict['serializer_version']
        serialized_no_version_str = json.dumps(serialized_dict)

        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(
                    cStringIO.StringIO(serialized_no_version_str))


    def test_deserialize_malformed_host_info_raises(self):
        """Deserializing a malformed host_info raises."""
        serialized_fp = self._serialize_to_fp(host_info.HostInfo())

        # Valid json with a valid version, but a required field is missing.
        serialized_dict = json.load(serialized_fp)
        del serialized_dict['labels']
        serialized_no_labels_str = json.dumps(serialized_dict)

        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(
                    cStringIO.StringIO(serialized_no_labels_str))


    def test_enforce_compatibility_version_1(self):
        """Tests that required fields are never dropped.

        Never change this test. If you must break compatibility, uprev the
        serializer version and add a new test for the newer version.

        Adding a field to compat_dict means we're making the new field
        mandatory. This breaks backwards compatibility.
        Removing a field from compat_dict means we're no longer requiring a
        field to be mandatory. This breaks forwards compatibility.
        """
        compat_dict = {
                'serializer_version': 1,
                'attributes': {},
                'labels': []
        }
        serialized_str = json.dumps(compat_dict)
        serialized_fp = cStringIO.StringIO(serialized_str)
        # There is nothing to assert on: the test passes as long as the
        # version 1 payload deserializes without raising.
        host_info.json_deserialize(serialized_fp)


    def test_serialize_pretty_print(self):
        """Serializing a host_info dumps the json in human-friendly format"""
        info = host_info.HostInfo(labels=['label1'],
                                  attributes={'attrib': 'val'})
        serialized_fp = self._serialize_to_fp(info)
        # inspect.cleandoc strips the indentation shared by the literal's
        # continuation lines, so the literal can be indented with the code.
        expected = """{
            "attributes": {
                "attrib": "val"
            },
            "labels": [
                "label1"
            ],
            "serializer_version": %d
        }""" % self.CURRENT_SERIALIZATION_VERSION
        self.assertEqual(serialized_fp.getvalue(), inspect.cleandoc(expected))


if __name__ == '__main__':
    unittest.main()