1#!/usr/bin/python
2#pylint: disable-msg=C0111
3
4import datetime
5import unittest
6
7import common
8from autotest_lib.frontend import setup_django_environment
9from autotest_lib.frontend.afe import frontend_test_utils
10from autotest_lib.client.common_lib import host_queue_entry_states
11from autotest_lib.database import database_connection
12from autotest_lib.frontend.afe import models, model_attributes
13from autotest_lib.scheduler import monitor_db
14from autotest_lib.scheduler import scheduler_lib
15from autotest_lib.scheduler import scheduler_models
16
# Assigned to the test database connection's ``debug`` attribute in
# BaseSchedulerModelsTest._set_monitor_stubs. Presumably enables SQL/debug
# logging in database_connection when True -- verify there before relying
# on it.
_DEBUG = False
18
19
class BaseSchedulerModelsTest(unittest.TestCase,
                              frontend_test_utils.FrontendTestMixin):
    """Common fixture for scheduler_models tests.

    Runs the frontend test setup from FrontendTestMixin, opens a translating
    test database connection, and stubs scheduler_models._db with it so the
    scheduler models under test use that connection.
    """
    _config_section = 'AUTOTEST_WEB'

    def _do_query(self, sql):
        """Execute a raw SQL statement on the test database connection."""
        self._database.execute(sql)


    def _set_monitor_stubs(self):
        """Connect the test database and point scheduler_models._db at it."""
        # Clear the instance cache as this is a brand new database.
        scheduler_models.DBObject._clear_instance_cache()

        self._database = (
            database_connection.TranslatingDatabase.get_test_database(
                translators=scheduler_lib._DB_TRANSLATORS))
        self._database.connect(db_type='django')
        self._database.debug = _DEBUG

        self.god.stub_with(scheduler_models, '_db', self._database)


    def setUp(self):
        self._frontend_common_setup()
        self._set_monitor_stubs()


    def tearDown(self):
        # Run the frontend teardown even if disconnecting the database
        # fails, so state from _frontend_common_setup (including the god
        # stubs) does not leak into later tests.
        try:
            self._database.disconnect()
        finally:
            self._frontend_common_teardown()


    def _update_hqe(self, set, where=''):
        """Run an UPDATE statement on afe_host_queue_entries.

        The parameter name ``set`` shadows the builtin but is kept so that
        keyword callers keep working.

        @param set: body of the SQL SET clause, e.g. "status='Queued'".
        @param where: optional body of the WHERE clause; when empty, every
                row in the table is updated.
        """
        query = 'UPDATE afe_host_queue_entries SET ' + set
        if where:
            query += ' WHERE ' + where
        self._do_query(query)
56
57
class DBObjectTest(BaseSchedulerModelsTest):
    """Tests for the generic DBObject behavior, exercised through Host/HQE."""

    def test_compare_fields_in_row(self):
        host = scheduler_models.Host(id=1)
        fields = list(host._fields)
        row_data = [getattr(host, fieldname) for fieldname in fields]
        self.assertEqual({}, host._compare_fields_in_row(row_data))
        row_data[fields.index('hostname')] = 'spam'
        self.assertEqual({'hostname': ('host1', 'spam')},
                         host._compare_fields_in_row(row_data))
        row_data[fields.index('id')] = 23
        self.assertEqual({'hostname': ('host1', 'spam'), 'id': (1, 23)},
                         host._compare_fields_in_row(row_data))


    def test_compare_fields_in_row_datetime_ignores_microseconds(self):
        # Plain decimal day literal: the former zero-padded "07" is an octal
        # literal in Python 2 and a SyntaxError in Python 3.
        datetime_with_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 7890)
        datetime_without_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 0)
        class TestTable(scheduler_models.DBObject):
            _table_name = 'test_table'
            _fields = ('id', 'test_datetime')
        tt = TestTable(row=[1, datetime_without_us])
        self.assertEqual({}, tt._compare_fields_in_row([1, datetime_with_us]))


    def test_always_query(self):
        host_a = scheduler_models.Host(id=2)
        self.assertEqual(host_a.hostname, 'host2')
        self._do_query('UPDATE afe_hosts SET hostname="host2-updated" '
                       'WHERE id=2')
        host_b = scheduler_models.Host(id=2, always_query=True)
        self.assertIs(host_a, host_b, 'Cached instance not returned.')
        self.assertEqual(host_a.hostname, 'host2-updated',
                         'Database was not re-queried')

        # If either of these are called, a query was made when it shouldn't be.
        host_a._compare_fields_in_row = lambda _: self.fail('eek! a query!')
        host_a._update_fields_from_row = host_a._compare_fields_in_row
        host_c = scheduler_models.Host(id=2, always_query=False)
        self.assertIs(host_a, host_c, 'Cached instance not returned')


    def test_delete(self):
        host = scheduler_models.Host(id=3)
        host.delete()
        # After delete() the host must not be loadable either from the
        # instance cache (always_query=False) or from the database
        # (always_query=True).
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=False)
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=True)


    def test_save(self):
        # Dummy Job to avoid creating a real one in the HostQueueEntry
        # __init__.
        class MockJob(object):
            def __init__(self, id, row):
                pass
            def tag(self):
                return 'MockJob'
        self.god.stub_with(scheduler_models, 'Job', MockJob)
        hqe = scheduler_models.HostQueueEntry(
                new_record=True,
                row=[0, 1, 2, 'Queued', None, 0, 0, 0, '.', None, False, None,
                     None])
        hqe.save()
        new_id = hqe.id
        # Force a re-query and verify that the correct data was stored.
        scheduler_models.DBObject._clear_instance_cache()
        hqe = scheduler_models.HostQueueEntry(id=new_id)
        self.assertEqual(hqe.id, new_id)
        self.assertEqual(hqe.job_id, 1)
        self.assertEqual(hqe.host_id, 2)
        self.assertEqual(hqe.status, 'Queued')
        self.assertEqual(hqe.meta_host, None)
        self.assertEqual(hqe.active, False)
        self.assertEqual(hqe.complete, False)
        self.assertEqual(hqe.deleted, False)
        self.assertEqual(hqe.execution_subdir, '.')
        self.assertEqual(hqe.started_on, None)
        self.assertEqual(hqe.finished_on, None)
136
137
class HostTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Host labels and hostname ordering."""

    def setUp(self):
        super(HostTest, self).setUp()
        # Tests below overwrite this module-level flag; remember the original.
        self.old_config = scheduler_models.RESPECT_STATIC_LABELS


    def tearDown(self):
        # Restore the flag even if the base teardown raises, so one test's
        # override cannot leak into later tests.
        try:
            super(HostTest, self).tearDown()
        finally:
            scheduler_models.RESPECT_STATIC_LABELS = self.old_config


    def _setup_static_labels(self):
        """Create a host carrying both regular and static labels.

        Regular labels: 'non_static_label' and 'static_platform'
        (platform=False, also registered as a ReplacedLabel).
        Static labels: 'no_reference_label' (platform=False) and
        'static_platform' (platform=True).

        @return the created models.Host ('test_host').
        """
        label1 = models.Label.objects.create(name='non_static_label')
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=False)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)

        static_label1 = models.StaticLabel.objects.create(
                name='no_reference_label', platform=False)
        static_platform = models.StaticLabel.objects.create(
                name=non_static_platform.name, platform=True)

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label1)
        host1.labels.add(non_static_platform)
        host1.static_labels.add(static_label1)
        host1.static_labels.add(static_platform)
        host1.save()
        return host1


    def test_platform_and_labels_with_respect(self):
        scheduler_models.RESPECT_STATIC_LABELS = True
        test_host = self._setup_static_labels()
        host = scheduler_models.Host(id=test_host.id)
        platform, all_labels = host.platform_and_labels()
        self.assertEqual(platform, 'static_platform')
        self.assertNotIn('no_reference_label', all_labels)
        self.assertEqual(all_labels, ['non_static_label', 'static_platform'])


    def test_platform_and_labels_without_respect(self):
        scheduler_models.RESPECT_STATIC_LABELS = False
        test_host = self._setup_static_labels()
        host = scheduler_models.Host(id=test_host.id)
        platform, all_labels = host.platform_and_labels()
        self.assertIsNone(platform)
        self.assertEqual(all_labels, ['non_static_label', 'static_platform'])


    def test_cmp_for_sort(self):
        """cmp_for_sort orders hostnames as listed in expected_order."""
        expected_order = [
                'alice', 'Host1', 'host2', 'host3', 'host09', 'HOST010',
                'host10', 'host11', 'yolkfolk']
        hostname_idx = list(scheduler_models.Host._fields).index('hostname')
        row = [None] * len(scheduler_models.Host._fields)
        hosts = []
        for hostname in expected_order:
            row[hostname_idx] = hostname
            hosts.append(scheduler_models.Host(row=row, new_record=True))

        host1 = hosts[expected_order.index('Host1')]
        host010 = hosts[expected_order.index('HOST010')]
        host10 = hosts[expected_order.index('host10')]
        host3 = hosts[expected_order.index('host3')]
        alice = hosts[expected_order.index('alice')]
        self.assertEqual(0, scheduler_models.Host.cmp_for_sort(host10, host10))
        self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host10, host010))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host010, host10))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host10))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host010))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host3, host10))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host3, host010))
        self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host3, host1))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host3))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(alice, host3))
        self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host3, alice))
        self.assertEqual(0, scheduler_models.Host.cmp_for_sort(alice, alice))

        hosts.sort(cmp=scheduler_models.Host.cmp_for_sort)
        self.assertEqual(expected_order, [h.hostname for h in hosts])

        hosts.reverse()
        hosts.sort(cmp=scheduler_models.Host.cmp_for_sort)
        self.assertEqual(expected_order, [h.hostname for h in hosts])
224
225
class HostQueueEntryTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.HostQueueEntry."""

    def _create_hqe(self, dependency_labels=(), **create_job_kwargs):
        """Create a job with exactly one queue entry and return that entry.

        @param dependency_labels: labels added to the job's dependencies.
        @param create_job_kwargs: forwarded to self._create_job().
        @return the job's single scheduler_models.HostQueueEntry.
        """
        job = self._create_job(**create_job_kwargs)
        for label in dependency_labels:
            job.dependency_labels.add(label)
        hqes = list(scheduler_models.HostQueueEntry.fetch(
                where='job_id=%d' % job.id))
        self.assertEqual(1, len(hqes))
        return hqes[0]


    def _check_hqe_labels(self, hqe, expected_labels):
        """Assert the names of hqe.get_labels() equal expected_labels."""
        expected_labels = set(expected_labels)
        label_names = set(label.name for label in hqe.get_labels())
        self.assertEqual(expected_labels, label_names)


    def test_get_labels_empty(self):
        hqe = self._create_hqe(hosts=[1])
        labels = list(hqe.get_labels())
        self.assertEqual([], labels)


    def test_get_labels_metahost(self):
        hqe = self._create_hqe(metahosts=[2])
        self._check_hqe_labels(hqe, ['label2'])


    def test_get_labels_dependencies(self):
        hqe = self._create_hqe(dependency_labels=(self.label3,),
                               metahosts=[1])
        self._check_hqe_labels(hqe, ['label1', 'label3'])


    def setup_abort_test(self, agent_finished=True):
        """Setup the variables for testing abort method.

        @param agent_finished: True to mock agent is finished before aborting
                               the hqe.
        @return hqe, dispatcher: Mock object of hqe and dispatcher to be used
                               to test abort method.
        """
        hqe = self._create_hqe(hosts=[1])
        hqe.aborted = True
        hqe.complete = False
        hqe.status = models.HostQueueEntry.Status.STARTING
        hqe.started_on = datetime.datetime.now()

        dispatcher = self.god.create_mock_class(monitor_db.Dispatcher,
                                                'Dispatcher')
        agent = self.god.create_mock_class(monitor_db.Agent, 'Agent')
        dispatcher.get_agents_for_entry.expect_call(hqe).and_return([agent])
        agent.is_done.expect_call().and_return(agent_finished)
        return hqe, dispatcher


    def test_abort_fail_with_unfinished_agent(self):
        """abort should fail if the hqe still has agent not finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=False)
        self.assertIsNone(hqe.finished_on)
        with self.assertRaises(AssertionError):
            hqe.abort(dispatcher)
        self.god.check_playback()
        # abort failed, finished_on should not be set
        self.assertIsNone(hqe.finished_on)


    def test_abort_success(self):
        """abort should succeed if all agents for the hqe are finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=True)
        self.assertIsNone(hqe.finished_on)
        hqe.abort(dispatcher)
        self.god.check_playback()
        self.assertIsNotNone(hqe.finished_on)


    def test_set_finished_on(self):
        """Test that finished_on is set when hqe completes."""
        for status in host_queue_entry_states.Status.values:
            hqe = self._create_hqe(hosts=[1])
            hqe.started_on = datetime.datetime.now()
            hqe.job.update_field('shard_id', 3)
            self.assertIsNone(hqe.finished_on)
            hqe.set_status(status)
            if status in host_queue_entry_states.COMPLETE_STATUSES:
                # Completion stamps finished_on and clears the job's shard.
                self.assertIsNotNone(hqe.finished_on)
                self.assertIsNone(hqe.job.shard_id)
            else:
                self.assertIsNone(hqe.finished_on)
                self.assertEqual(hqe.job.shard_id, 3)
317
318
class JobTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Job and pre-job special task scheduling."""

    def setUp(self):
        super(JobTest, self).setUp()
        # Tasks recorded by the stubbed SpecialTask.objects.create. Initialized
        # here so the stub never sees a missing attribute; reset per call by
        # _test_pre_job_tasks_helper.
        self._tasks = []

        def _mock_create(**kwargs):
            """Save and record a SpecialTask, returning it like the real
            Django manager create() does."""
            task = models.SpecialTask(**kwargs)
            task.save()
            self._tasks.append(task)
            return task
        self.god.stub_with(models.SpecialTask.objects, 'create', _mock_create)


    def _test_pre_job_tasks_helper(
            self, reboot_before=model_attributes.RebootBefore.ALWAYS):
        """Schedule pre-job tasks for HQE id 1 and return what was created.

        Overrides the HQE's job.reboot_before, then calls
        HQE._do_schedule_pre_job_tasks().

        @param reboot_before: value assigned to the HQE's job.reboot_before.
        @return list of the models.SpecialTask objects created through the
                stubbed SpecialTask.objects.create, in creation order.
        """
        self._tasks = []
        queue_entry = scheduler_models.HostQueueEntry.fetch('id = 1')[0]
        queue_entry.job.reboot_before = reboot_before
        queue_entry._do_schedule_pre_job_tasks()
        return self._tasks


    def test_job_request_abort(self):
        django_job = self._create_job(hosts=[5, 6])
        job = scheduler_models.Job(django_job.id)
        job.request_abort()
        django_hqes = list(models.HostQueueEntry.objects.filter(job=job.id))
        for hqe in django_hqes:
            self.assertTrue(hqe.aborted)


    def _check_special_tasks(self, tasks, task_types):
        """Assert tasks pairwise match task_types.

        @param tasks: SpecialTask objects to check; all must target host 1.
        @param task_types: list of (task type, queue entry id) tuples; a falsy
                queue entry id skips the queue-entry check for that task.
        """
        self.assertEqual(len(tasks), len(task_types))
        for task, (task_type, queue_entry_id) in zip(tasks, task_types):
            self.assertEqual(task.task, task_type)
            self.assertEqual(task.host.id, 1)
            if queue_entry_id:
                self.assertEqual(task.queue_entry.id, queue_entry_id)


    def test_run_asynchronous(self):
        self._create_job(hosts=[1, 2])

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2])
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_verify(self):
        self._create_job(hosts=[1, 2], synchronous=True)

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_do_not_reset(self):
        # NOTE(review): this test was previously shadowed by a later method
        # with the same name and therefore never ran; confirm its
        # expectation still holds.
        job = self._create_job(hosts=[1, 2])
        job.run_reset = False
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self.assertEqual(tasks, [])


    def test_run_synchronous_do_not_reset_no_RebootBefore(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    def test_run_asynchronous_do_not_reset_no_RebootBefore(self):
        # Renamed from a duplicate "test_run_asynchronous_do_not_reset" that
        # silently replaced the earlier test of that name.
        job = self._create_job(hosts=[1, 2], synchronous=False)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    def test_reboot_before_always(self):
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.ALWAYS
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [
                (models.SpecialTask.Task.RESET, None)
            ])


    def _test_reboot_before_if_dirty_helper(self):
        """Create a single-host IF_DIRTY job and expect one RESET task."""
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.IF_DIRTY
        job.save()

        tasks = self._test_pre_job_tasks_helper()
        task_types = [(models.SpecialTask.Task.RESET, None)]

        self._check_special_tasks(tasks, task_types)


    def test_reboot_before_if_dirty(self):
        models.Host.smart_get(1).update_object(dirty=True)
        self._test_reboot_before_if_dirty_helper()


    def test_reboot_before_not_dirty(self):
        models.Host.smart_get(1).update_object(dirty=False)
        self._test_reboot_before_if_dirty_helper()
461
462
# Run the tests in this module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
465