1#!/usr/bin/python
2#pylint: disable-msg=C0111
3
import datetime
import functools
import unittest

import common
from autotest_lib.frontend import setup_django_environment
from autotest_lib.frontend.afe import frontend_test_utils
from autotest_lib.client.common_lib import host_queue_entry_states
from autotest_lib.client.common_lib.test_utils import mock
from autotest_lib.database import database_connection
from autotest_lib.frontend.afe import models, model_attributes
from autotest_lib.scheduler import monitor_db
from autotest_lib.scheduler import scheduler_lib
from autotest_lib.scheduler import scheduler_models
17
# Copied onto the test database connection's `debug` attribute by
# BaseSchedulerModelsTest._set_monitor_stubs().
_DEBUG = False
19
20
class BaseSchedulerModelsTest(unittest.TestCase,
                              frontend_test_utils.FrontendTestMixin):
    """Shared fixture for scheduler_models tests that need a database.

    setUp runs the frontend mixin's common setup and then replaces
    scheduler_models._db with a connection to the test database. tearDown
    disconnects that connection and runs the mixin's common teardown.
    """
    _config_section = 'AUTOTEST_WEB'

    def _do_query(self, sql):
        """Execute a raw SQL statement on the test database connection.

        @param sql: SQL string handed unchanged to the connection's
                    execute() method.
        """
        self._database.execute(sql)


    def _set_monitor_stubs(self):
        """Connect to the test database and stub it into scheduler_models."""
        # Clear the instance cache as this is a brand new database.
        scheduler_models.DBObject._clear_instance_cache()

        self._database = (
            database_connection.TranslatingDatabase.get_test_database(
                translators=scheduler_lib._DB_TRANSLATORS))
        self._database.connect(db_type='django')
        self._database.debug = _DEBUG

        # NOTE(review): self.god is not created in this class; presumably
        # _frontend_common_setup() provides it -- confirm against the mixin.
        self.god.stub_with(scheduler_models, '_db', self._database)


    def setUp(self):
        """Run the mixin's common setup, then install the database stub."""
        self._frontend_common_setup()
        self._set_monitor_stubs()


    def tearDown(self):
        """Disconnect the test database and run the mixin's teardown.

        The mixin teardown runs even when disconnect() raises, so a broken
        connection cannot leave this test's fixtures in place for the
        tests that follow.
        """
        try:
            self._database.disconnect()
        finally:
            self._frontend_common_teardown()


    def _update_hqe(self, set, where=''):
        """Run an UPDATE statement against afe_host_queue_entries.

        @param set: SQL text placed after the SET keyword. The name shadows
                    the builtin but is kept so keyword callers keep working.
        @param where: Optional SQL condition. When empty, no WHERE clause
                      is appended to the statement.
        """
        query = 'UPDATE afe_host_queue_entries SET ' + set
        if where:
            query += ' WHERE ' + where
        self._do_query(query)
57
58
class DelayedCallTaskTest(unittest.TestCase):
    """Tests for scheduler_models.DelayedCallTask."""

    def setUp(self):
        """Create the mock god used to build and verify mock functions."""
        self.god = mock.mock_god()


    def tearDown(self):
        """Undo any stubs installed through the mock god."""
        self.god.unstub_all()


    def test_delayed_call(self):
        """The callback fires only once the mocked clock passes end_time."""
        # Scripted clock: one reading for construction, one for each poll().
        test_time = self.god.create_mock_function('time')
        test_time.expect_call().and_return(33)
        test_time.expect_call().and_return(34.01)
        test_time.expect_call().and_return(34.99)
        test_time.expect_call().and_return(35.01)

        # The call count lives on the function object so the nested function
        # can update it without `nonlocal`.
        def test_callback():
            test_callback.calls += 1
        test_callback.calls = 0

        delay_task = scheduler_models.DelayedCallTask(
                delay_seconds=2, callback=test_callback,
                now_func=test_time)  # time 33
        self.assertEqual(35, delay_task.end_time)
        delay_task.poll()  # activates the task and polls it once, time 34.01
        self.assertEqual(0, test_callback.calls, "callback called early")
        delay_task.poll()  # time 34.99
        self.assertEqual(0, test_callback.calls, "callback called early")
        delay_task.poll()  # time 35.01
        self.assertEqual(1, test_callback.calls)
        self.assertTrue(delay_task.is_done())
        self.assertTrue(delay_task.success)
        self.assertFalse(delay_task.aborted)
        self.god.check_playback()


    def test_delayed_call_abort(self):
        """An aborted task is done, flagged aborted and not successful."""
        delay_task = scheduler_models.DelayedCallTask(
                delay_seconds=987654, callback=lambda : None)
        delay_task.abort()
        self.assertTrue(delay_task.aborted)
        self.assertTrue(delay_task.is_done())
        self.assertFalse(delay_task.success)
        self.god.check_playback()
101
102
class DBObjectTest(BaseSchedulerModelsTest):
    """Tests for the generic scheduler_models.DBObject behaviour."""

    def test_compare_fields_in_row(self):
        """_compare_fields_in_row reports {field: (current, row_value)}."""
        host = scheduler_models.Host(id=1)
        fields = list(host._fields)
        row_data = [getattr(host, fieldname) for fieldname in fields]
        # A row identical to the object's own values yields no differences.
        self.assertEqual({}, host._compare_fields_in_row(row_data))
        row_data[fields.index('hostname')] = 'spam'
        self.assertEqual({'hostname': ('host1', 'spam')},
                         host._compare_fields_in_row(row_data))
        row_data[fields.index('id')] = 23
        self.assertEqual({'hostname': ('host1', 'spam'), 'id': (1, 23)},
                         host._compare_fields_in_row(row_data))


    def test_compare_fields_in_row_datetime_ignores_microseconds(self):
        """Datetimes differing only in microseconds compare as equal."""
        # Plain decimal day: a leading-zero literal such as 07 is a syntax
        # error on Python 3.
        datetime_with_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 7890)
        datetime_without_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 0)
        class TestTable(scheduler_models.DBObject):
            _table_name = 'test_table'
            _fields = ('id', 'test_datetime')
        tt = TestTable(row=[1, datetime_without_us])
        self.assertEqual({}, tt._compare_fields_in_row([1, datetime_with_us]))


    def test_always_query(self):
        """always_query=True refreshes the cached instance; False does not."""
        host_a = scheduler_models.Host(id=2)
        self.assertEqual(host_a.hostname, 'host2')
        self._do_query('UPDATE afe_hosts SET hostname="host2-updated" '
                       'WHERE id=2')
        host_b = scheduler_models.Host(id=2, always_query=True)
        self.assertIs(host_a, host_b, 'Cached instance not returned.')
        self.assertEqual(host_a.hostname, 'host2-updated',
                         'Database was not re-queried')

        # If either of these are called, a query was made when it shouldn't be.
        host_a._compare_fields_in_row = lambda _: self.fail('eek! a query!')
        host_a._update_fields_from_row = host_a._compare_fields_in_row
        host_c = scheduler_models.Host(id=2, always_query=False)
        self.assertIs(host_a, host_c, 'Cached instance not returned')


    def test_delete(self):
        """A deleted row can no longer be loaded, with or without a query."""
        host = scheduler_models.Host(id=3)
        host.delete()
        # Arguments after the callable are forwarded to Host().
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=False)
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=True)


    def test_save(self):
        """save() on a new record stores it and assigns it an id."""
        # Dummy Job to avoid creating a real one in HostQueueEntry.__init__.
        class MockJob(object):
            def __init__(self, id, row):
                pass
            def tag(self):
                return 'MockJob'
        self.god.stub_with(scheduler_models, 'Job', MockJob)
        hqe = scheduler_models.HostQueueEntry(
                new_record=True,
                row=[0, 1, 2, 'Queued', None, 0, 0, 0, '.', None, False, None,
                     None])
        hqe.save()
        new_id = hqe.id
        # Force a re-query and verify that the correct data was stored.
        scheduler_models.DBObject._clear_instance_cache()
        hqe = scheduler_models.HostQueueEntry(id=new_id)
        self.assertEqual(hqe.id, new_id)
        self.assertEqual(hqe.job_id, 1)
        self.assertEqual(hqe.host_id, 2)
        self.assertEqual(hqe.status, 'Queued')
        self.assertEqual(hqe.meta_host, None)
        self.assertEqual(hqe.active, False)
        self.assertEqual(hqe.complete, False)
        self.assertEqual(hqe.deleted, False)
        self.assertEqual(hqe.execution_subdir, '.')
        self.assertEqual(hqe.started_on, None)
        self.assertEqual(hqe.finished_on, None)
180
181
class HostTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Host."""

    def test_cmp_for_sort(self):
        """cmp_for_sort yields the ordering listed in expected_order."""
        expected_order = [
                'alice', 'Host1', 'host2', 'host3', 'host09', 'HOST010',
                'host10', 'host11', 'yolkfolk']
        hostname_idx = list(scheduler_models.Host._fields).index('hostname')
        row = [None] * len(scheduler_models.Host._fields)
        hosts = []
        for hostname in expected_order:
            row[hostname_idx] = hostname
            hosts.append(scheduler_models.Host(row=row, new_record=True))

        host1 = hosts[expected_order.index('Host1')]
        host010 = hosts[expected_order.index('HOST010')]
        host10 = hosts[expected_order.index('host10')]
        host3 = hosts[expected_order.index('host3')]
        alice = hosts[expected_order.index('alice')]

        # (expected result, left operand, right operand)
        comparisons = [
                (0, host10, host10),
                (1, host10, host010),
                (-1, host010, host10),
                (-1, host1, host10),
                (-1, host1, host010),
                (-1, host3, host10),
                (-1, host3, host010),
                (1, host3, host1),
                (-1, host1, host3),
                (-1, alice, host3),
                (1, host3, alice),
                (0, alice, alice),
        ]
        for expected, left, right in comparisons:
            actual = scheduler_models.Host.cmp_for_sort(left, right)
            self.assertEqual(
                    expected, actual,
                    'cmp_for_sort(%r, %r) returned %r, expected %r'
                    % (left.hostname, right.hostname, actual, expected))

        # list.sort() lost its cmp= argument in Python 3; cmp_to_key adapts
        # the comparison function and is also available on Python 2.7.
        sort_key = functools.cmp_to_key(scheduler_models.Host.cmp_for_sort)

        hosts.sort(key=sort_key)
        self.assertEqual(expected_order, [h.hostname for h in hosts])

        # Sorting from the reversed order must give the same result.
        hosts.reverse()
        hosts.sort(key=sort_key)
        self.assertEqual(expected_order, [h.hostname for h in hosts])
218
219
class HostQueueEntryTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.HostQueueEntry."""

    def _create_hqe(self, dependency_labels=(), **create_job_kwargs):
        """Create a job and return its single scheduler HostQueueEntry.

        @param dependency_labels: Labels added to the job's dependency_labels.
        @param create_job_kwargs: Keyword arguments forwarded to _create_job.
        @return The one HostQueueEntry fetched for the new job; the test
                fails if the job does not have exactly one.
        """
        job = self._create_job(**create_job_kwargs)
        for label in dependency_labels:
            job.dependency_labels.add(label)
        hqes = list(scheduler_models.HostQueueEntry.fetch(
                where='job_id=%d' % job.id))
        self.assertEqual(1, len(hqes))
        return hqes[0]


    def _check_hqe_labels(self, hqe, expected_labels):
        """Assert that hqe.get_labels() yields exactly the expected names.

        @param hqe: HostQueueEntry whose labels are checked.
        @param expected_labels: Iterable of label names; order is ignored.
        """
        expected_labels = set(expected_labels)
        label_names = set(label.name for label in hqe.get_labels())
        self.assertEqual(expected_labels, label_names)


    def test_get_labels_empty(self):
        """An entry created with hosts=[1] reports no labels."""
        hqe = self._create_hqe(hosts=[1])
        labels = list(hqe.get_labels())
        self.assertEqual([], labels)


    def test_get_labels_metahost(self):
        """An entry created with metahosts=[2] reports 'label2'."""
        hqe = self._create_hqe(metahosts=[2])
        self._check_hqe_labels(hqe, ['label2'])


    def test_get_labels_dependencies(self):
        """Dependency labels are reported alongside the metahost label."""
        hqe = self._create_hqe(dependency_labels=(self.label3,),
                               metahosts=[1])
        self._check_hqe_labels(hqe, ['label1', 'label3'])


    def setup_abort_test(self, agent_finished=True):
        """Setup the variables for testing abort method.

        @param agent_finished: True to mock agent is finished before aborting
                               the hqe.
        @return hqe, dispatcher: Mock object of hqe and dispatcher to be used
                               to test abort method.
        """
        hqe = self._create_hqe(hosts=[1])
        hqe.aborted = True
        hqe.complete = False
        hqe.status = models.HostQueueEntry.Status.STARTING
        hqe.started_on = datetime.datetime.now()

        dispatcher = self.god.create_mock_class(monitor_db.Dispatcher,
                                                'Dispatcher')
        agent = self.god.create_mock_class(monitor_db.Agent, 'Agent')
        # abort() is expected to look up the entry's agents and ask each
        # whether it is done.
        dispatcher.get_agents_for_entry.expect_call(hqe).and_return([agent])
        agent.is_done.expect_call().and_return(agent_finished)
        return hqe, dispatcher


    def test_abort_fail_with_unfinished_agent(self):
        """abort should fail if the hqe still has agent not finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=False)
        self.assertIsNone(hqe.finished_on)
        with self.assertRaises(AssertionError):
            hqe.abort(dispatcher)
        self.god.check_playback()
        # abort failed, finished_on should not be set
        self.assertIsNone(hqe.finished_on)


    def test_abort_success(self):
        """abort should succeed if all agents for the hqe are finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=True)
        self.assertIsNone(hqe.finished_on)
        hqe.abort(dispatcher)
        self.god.check_playback()
        self.assertIsNotNone(hqe.finished_on)


    def test_set_finished_on(self):
        """Test that finished_on is set when hqe completes."""
        for status in host_queue_entry_states.Status.values:
            hqe = self._create_hqe(hosts=[1])
            hqe.started_on = datetime.datetime.now()
            hqe.job.update_field('shard_id', 3)
            self.assertIsNone(hqe.finished_on)
            hqe.set_status(status)
            if status in host_queue_entry_states.COMPLETE_STATUSES:
                # Completion stamps finished_on and clears the job's shard_id.
                self.assertIsNotNone(hqe.finished_on)
                self.assertIsNone(hqe.job.shard_id)
            else:
                self.assertIsNone(hqe.finished_on)
                self.assertEqual(hqe.job.shard_id, 3)
311
312
class JobTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Job and the pre-job special tasks."""

    def setUp(self):
        """Stub SpecialTask.objects.create so created tasks are recorded."""
        super(JobTest, self).setUp()

        def _mock_create(**kwargs):
            # Save the task for real, and remember it in self._tasks so
            # the tests can inspect what was scheduled.
            task = models.SpecialTask(**kwargs)
            task.save()
            self._tasks.append(task)
        self.god.stub_with(models.SpecialTask.objects, 'create', _mock_create)


    def _test_pre_job_tasks_helper(self,
                            reboot_before=model_attributes.RebootBefore.ALWAYS):
        """
        Calls HQE._do_schedule_pre_job_tasks() on the queue entry with id 1
        and returns the list of special tasks created by that call.

        @param reboot_before: RebootBefore value assigned to the entry's job
                              before scheduling.
        @return List of SpecialTask objects recorded by the create stub.
        """
        self._tasks = []
        queue_entry = scheduler_models.HostQueueEntry.fetch('id = 1')[0]
        queue_entry.job.reboot_before = reboot_before
        queue_entry._do_schedule_pre_job_tasks()
        return self._tasks


    def test_job_request_abort(self):
        """request_abort() marks every queue entry of the job as aborted."""
        django_job = self._create_job(hosts=[5, 6])
        job = scheduler_models.Job(django_job.id)
        job.request_abort()
        django_hqes = list(models.HostQueueEntry.objects.filter(job=job.id))
        for hqe in django_hqes:
            self.assertTrue(hqe.aborted)


    def _check_special_tasks(self, tasks, task_types):
        """Assert that tasks match task_types pairwise, all on host 1.

        @param tasks: SpecialTask objects to check.
        @param task_types: Sequence of (task_type, queue_entry_id) tuples;
                           a falsy queue_entry_id skips the queue entry check.
        """
        self.assertEqual(len(tasks), len(task_types))
        for task, (task_type, queue_entry_id) in zip(tasks, task_types):
            self.assertEqual(task.task, task_type)
            self.assertEqual(task.host.id, 1)
            if queue_entry_id:
                self.assertEqual(task.queue_entry.id, queue_entry_id)


    def test_run_asynchronous(self):
        """An asynchronous job schedules a RESET for queue entry 1."""
        self._create_job(hosts=[1, 2])

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_skip_verify(self):
        """Disabling run_verify still schedules a RESET (asynchronous)."""
        job = self._create_job(hosts=[1, 2])
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_verify(self):
        """A synchronous job schedules a RESET for queue entry 1."""
        self._create_job(hosts=[1, 2], synchronous=True)

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_skip_verify(self):
        """Disabling run_verify still schedules a RESET (synchronous)."""
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_do_not_reset(self):
        """With run_reset and run_verify off, no special task is expected."""
        job = self._create_job(hosts=[1, 2])
        job.run_reset = False
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self.assertEqual(tasks, [])


    def test_run_synchronous_do_not_reset_no_RebootBefore(self):
        """RebootBefore.NEVER yields a VERIFY task (synchronous)."""
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                            reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    # This method used to be a second test_run_asynchronous_do_not_reset,
    # which replaced the earlier definition in the class namespace so that
    # only one of the two ever ran. Its name now mirrors the synchronous
    # variant above.
    def test_run_asynchronous_do_not_reset_no_RebootBefore(self):
        """RebootBefore.NEVER yields a VERIFY task (asynchronous)."""
        job = self._create_job(hosts=[1, 2], synchronous=False)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                            reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    def test_reboot_before_always(self):
        """RebootBefore.ALWAYS schedules a RESET."""
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.ALWAYS
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [
                (models.SpecialTask.Task.RESET, None)
            ])


    def _test_reboot_before_if_dirty_helper(self):
        """Save a job with RebootBefore.IF_DIRTY and expect a single RESET.

        The pre-job helper is called with its default reboot_before, so the
        fetched queue entry's job is set to ALWAYS before scheduling.
        """
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.IF_DIRTY
        job.save()

        tasks = self._test_pre_job_tasks_helper()
        task_types = [(models.SpecialTask.Task.RESET, None)]

        self._check_special_tasks(tasks, task_types)


    def test_reboot_before_if_dirty(self):
        """Host 1 marked dirty: a RESET is scheduled."""
        models.Host.smart_get(1).update_object(dirty=True)
        self._test_reboot_before_if_dirty_helper()


    def test_reboot_before_not_dirty(self):
        """Host 1 marked clean: a RESET is still scheduled."""
        models.Host.smart_get(1).update_object(dirty=False)
        self._test_reboot_before_if_dirty_helper()
455
456
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
459