monitor_db_unittest.py revision 364fe862f88bcaaefaa40dd145a777bee840ec9b
1#!/usr/bin/python
2
3import unittest, time, subprocess, os, StringIO, tempfile, datetime
4import MySQLdb
5import common
6from autotest_lib.frontend import setup_django_environment
7from autotest_lib.frontend import setup_test_environment
8from autotest_lib.client.common_lib import global_config, host_protections
9from autotest_lib.client.common_lib.test_utils import mock
10from autotest_lib.database import database_connection, migrate
11from autotest_lib.frontend.afe import models
12from autotest_lib.scheduler import monitor_db
13
# Copied onto the test database connection's `debug` attribute by
# BaseSchedulerTest._open_test_db().
_DEBUG = False
15
class Dummy(object):
    """Empty placeholder class; tests hang arbitrary attributes on instances."""
18
19
class IsRow(mock.argument_comparator):
    """Argument comparator that matches a row by the id in its first column."""

    def __init__(self, row_id):
        self.row_id = row_id


    def is_satisfied_by(self, parameter):
        """Return whether the first element of `parameter` equals row_id."""
        columns = list(parameter)
        first_column = columns[0]
        return first_column == self.row_id


    def __str__(self):
        description = 'row with id %s' % self.row_id
        return description
31
32
class BaseSchedulerTest(unittest.TestCase):
    """
    Shared fixture for scheduler tests that run against a test database.

    setUp() restores the test database from a backup, inserts a small set of
    users, ACL groups, hosts and labels, points monitor_db._db at that
    database and creates a monitor_db.Dispatcher as self._dispatcher.
    """
    _config_section = 'AUTOTEST_WEB'
    # set to True by _initialize_test_db() so the database is created only
    # once for the class it was called on
    _test_db_initialized = False

    def _do_query(self, sql):
        """Execute the SQL string `sql` on the test database connection."""
        self._database.execute(sql)


    @classmethod
    def _initialize_test_db(cls):
        """
        Create the test database file, run syncdb on it and save a backup.

        Does nothing when it has already run for this class.
        """
        if cls._test_db_initialized:
            return
        temp_fd, cls._test_db_file = tempfile.mkstemp(suffix='.monitor_test')
        # mkstemp() hands back an open descriptor; only the path is used
        os.close(temp_fd)
        setup_test_environment.set_test_database(cls._test_db_file)
        setup_test_environment.run_syncdb()
        cls._test_db_backup = setup_test_environment.backup_test_database()
        cls._test_db_initialized = True


    def _open_test_db(self):
        """Restore the database from the backup and connect to it."""
        self._initialize_test_db()
        setup_test_environment.restore_test_database(self._test_db_backup)
        self._database = (
            database_connection.DatabaseConnection.get_test_database(
                self._test_db_file))
        self._database.connect()
        self._database.debug = _DEBUG


    def _close_test_db(self):
        """Disconnect from the test database."""
        self._database.disconnect()


    def _set_monitor_stubs(self):
        """Make monitor_db use the test database connection."""
        monitor_db._db = self._database


    def _fill_in_test_data(self):
        """
        Insert the rows every test starts from.

        Creates user 'my_user' in ACL group 'my_acl', hosts host1..host4
        (all in that ACL group) and labels label1..label3.  label3 is marked
        only_if_needed; host1 gets label1 and host2 gets label2.
        """
        user = models.User.objects.create(login='my_user')
        acl_group = models.AclGroup.objects.create(name='my_acl')
        acl_group.users.add(user)

        hosts = [models.Host.objects.create(hostname=hostname) for hostname in
                 ('host1', 'host2', 'host3', 'host4')]
        acl_group.hosts = hosts

        labels = [models.Label.objects.create(name=name) for name in
                  ('label1', 'label2', 'label3')]
        labels[2].only_if_needed = True
        labels[2].save()
        hosts[0].labels.add(labels[0])
        hosts[1].labels.add(labels[1])


    def setUp(self):
        self.god = mock.mock_god()
        self._open_test_db()
        self._fill_in_test_data()
        self._set_monitor_stubs()
        self._dispatcher = monitor_db.Dispatcher()


    def tearDown(self):
        # always undo the stubs, even if disconnecting blows up, so a
        # database problem in one test cannot leave stubs behind for the next
        try:
            self._close_test_db()
        finally:
            self.god.unstub_all()


    def _create_job(self, hosts=(), metahosts=(), priority=0, active=0,
                    synchronous=False):
        """
        Create a job owned by 'my_user' together with its queue entries.

        @param hosts: host ids; each gets a HostQueueEntry and an
                IneligibleHostQueue row for the job.
        @param metahosts: label ids; each gets a HostQueueEntry with that
                meta_host_id.
        @param priority: priority given to the job and to every queue entry.
        @param active: value of the `active` field of every queue entry.
        @param synchronous: when true the job gets synch_type 2, otherwise 1.
        @return: the created models.Job.
        """
        if synchronous:
            synch_type = 2
        else:
            synch_type = 1
        created_on = datetime.datetime(2008, 1, 1)
        job = models.Job.objects.create(name='test', owner='my_user',
                                        priority=priority,
                                        synch_type=synch_type,
                                        created_on=created_on)
        for host_id in hosts:
            models.HostQueueEntry.objects.create(job=job, priority=priority,
                                                 host_id=host_id, active=active)
            models.IneligibleHostQueue.objects.create(job=job, host_id=host_id)
        for label_id in metahosts:
            models.HostQueueEntry.objects.create(job=job, priority=priority,
                                                 meta_host_id=label_id,
                                                 active=active)
        return job


    def _create_job_simple(self, hosts, use_metahost=False,
                          priority=0, active=0):
        """
        An alternative interface to _create_job.

        The ids in `hosts` are passed as metahost label ids when use_metahost
        is true and as host ids otherwise.
        """
        args = {'hosts' : [], 'metahosts' : []}
        if use_metahost:
            args['metahosts'] = hosts
        else:
            args['hosts'] = hosts
        return self._create_job(priority=priority, active=active, **args)


    def _update_hqe(self, set, where=''):
        """
        Run an UPDATE on host_queue_entries.

        @param set: SQL text placed after SET.
        @param where: optional SQL text placed after WHERE; all rows are
                updated when it is empty.
        """
        query = 'UPDATE host_queue_entries SET ' + set
        if where:
            query += ' WHERE ' + where
        self._do_query(query)
136
137
class DispatcherSchedulingTest(BaseSchedulerTest):
    """
    Tests for Dispatcher._schedule_new_jobs().

    HostQueueEntry.run is replaced by a stub that records a
    (job id, host id) pair for every entry the dispatcher decides to run;
    the tests then assert on the recorded pairs.
    """
    _jobs_scheduled = []

    def _set_monitor_stubs(self):
        super(DispatcherSchedulingTest, self)._set_monitor_stubs()
        def run_stub(hqe_self, assigned_host=None):
            hqe_self.set_status('Starting')
            if hqe_self.meta_host:
                host = assigned_host
            else:
                host = hqe_self.host
            self._record_job_scheduled(hqe_self.job.id, host.id)
            return Dummy()
        # install the stub through the mock god instead of assigning the
        # attribute directly, so that tearDown()'s unstub_all() puts the real
        # method back and the stub cannot leak into other test classes
        self.god.stub_with(monitor_db.HostQueueEntry, 'run', run_stub)


    def _record_job_scheduled(self, job_id, host_id):
        """Remember a scheduling, failing if the same pair was seen before."""
        record = (job_id, host_id)
        self.assert_(record not in self._jobs_scheduled,
                     'Job %d scheduled on host %d twice' %
                     (job_id, host_id))
        self._jobs_scheduled.append(record)


    def _assert_job_scheduled_on(self, job_id, host_id):
        """
        Assert the pair was recorded and remove it, so that
        _check_for_extra_schedulings() only sees unexpected schedulings.
        """
        record = (job_id, host_id)
        self.assert_(record in self._jobs_scheduled,
                     'Job %d not scheduled on host %d as expected\n'
                     'Jobs scheduled: %s' %
                     (job_id, host_id, self._jobs_scheduled))
        self._jobs_scheduled.remove(record)


    def _check_for_extra_schedulings(self):
        """Fail if any recorded scheduling has not been asserted on."""
        if self._jobs_scheduled:
            self.fail('Extra jobs scheduled: ' +
                      str(self._jobs_scheduled))


    def _convert_jobs_to_metahosts(self, *job_ids):
        """Turn the queue entries of the given jobs into metahost entries."""
        sql_tuple = '(' + ','.join(str(i) for i in job_ids) + ')'
        self._do_query('UPDATE host_queue_entries SET '
                       'meta_host=host_id, host_id=NULL '
                       'WHERE job_id IN ' + sql_tuple)


    def _lock_host(self, host_id):
        """Mark the host with the given id as locked."""
        self._do_query('UPDATE hosts SET locked=1 WHERE id=' +
                       str(host_id))


    def setUp(self):
        super(DispatcherSchedulingTest, self).setUp()
        # per-test list; shadows the class-level attribute
        self._jobs_scheduled = []


    def _test_basic_scheduling_helper(self, use_metahosts):
        'Basic scheduling, for either plain hosts or metahosts'
        self._create_job_simple([1], use_metahosts)
        self._create_job_simple([2], use_metahosts)
        self._dispatcher._schedule_new_jobs()
        self._assert_job_scheduled_on(1, 1)
        self._assert_job_scheduled_on(2, 2)
        self._check_for_extra_schedulings()


    def _test_priorities_helper(self, use_metahosts):
        'Test prioritization ordering'
        self._create_job_simple([1], use_metahosts)
        self._create_job_simple([2], use_metahosts)
        self._create_job_simple([1,2], use_metahosts)
        self._create_job_simple([1], use_metahosts, priority=1)
        self._dispatcher._schedule_new_jobs()
        self._assert_job_scheduled_on(4, 1) # higher priority
        self._assert_job_scheduled_on(2, 2) # earlier job over later
        self._check_for_extra_schedulings()


    def _test_hosts_ready_helper(self, use_metahosts):
        """
        Only hosts that are status=Ready, unlocked and not invalid get
        scheduled.
        """
        self._create_job_simple([1], use_metahosts)
        self._do_query('UPDATE hosts SET status="Running" WHERE id=1')
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()

        self._do_query('UPDATE hosts SET status="Ready", locked=1 '
                       'WHERE id=1')
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()

        self._do_query('UPDATE hosts SET locked=0, invalid=1 '
                       'WHERE id=1')
        self._dispatcher._schedule_new_jobs()
        # an invalid host is still expected to be used by a non-metahost
        # entry (see test_non_metahost_on_invalid_host)
        if not use_metahosts:
            self._assert_job_scheduled_on(1, 1)
        self._check_for_extra_schedulings()


    def _test_hosts_idle_helper(self, use_metahosts):
        'Only idle hosts get scheduled'
        self._create_job(hosts=[1], active=1)
        self._create_job_simple([1], use_metahosts)
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()


    def _test_obey_ACLs_helper(self, use_metahosts):
        'Nothing is scheduled on a host removed from the ACL group'
        self._do_query('DELETE FROM acl_groups_hosts WHERE host_id=1')
        self._create_job_simple([1], use_metahosts)
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()


    def _test_only_if_needed_labels_helper(self, use_metahosts):
        # apply only_if_needed label3 to host1
        label3 = models.Label.smart_get('label3')
        models.Host.smart_get('host1').labels.add(label3)

        job = self._create_job_simple([1], use_metahosts)
        # if the job doesn't depend on label3, there should be no scheduling
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()

        # now make the job depend on label3
        job.dependency_labels.add(label3)
        self._dispatcher._schedule_new_jobs()
        self._assert_job_scheduled_on(1, 1)
        self._check_for_extra_schedulings()

        if use_metahosts:
            # should also work if the metahost is the only_if_needed label
            self._do_query('DELETE FROM jobs_dependency_labels')
            self._create_job(metahosts=[3])
            self._dispatcher._schedule_new_jobs()
            self._assert_job_scheduled_on(2, 1)
            self._check_for_extra_schedulings()


    def test_basic_scheduling(self):
        self._test_basic_scheduling_helper(False)


    def test_priorities(self):
        self._test_priorities_helper(False)


    def test_hosts_ready(self):
        self._test_hosts_ready_helper(False)


    def test_hosts_idle(self):
        self._test_hosts_idle_helper(False)


    def test_obey_ACLs(self):
        self._test_obey_ACLs_helper(False)


    def test_only_if_needed_labels(self):
        self._test_only_if_needed_labels_helper(False)


    def test_non_metahost_on_invalid_host(self):
        """
        Non-metahost entries can get scheduled on invalid hosts (this is how
        one-time hosts work).
        """
        self._do_query('UPDATE hosts SET invalid=1')
        self._test_basic_scheduling_helper(False)


    def test_metahost_scheduling(self):
        """
        Basic metahost scheduling
        """
        self._test_basic_scheduling_helper(True)


    def test_metahost_priorities(self):
        self._test_priorities_helper(True)


    def test_metahost_hosts_ready(self):
        self._test_hosts_ready_helper(True)


    def test_metahost_hosts_idle(self):
        self._test_hosts_idle_helper(True)


    def test_metahost_obey_ACLs(self):
        self._test_obey_ACLs_helper(True)


    def test_metahost_only_if_needed_labels(self):
        self._test_only_if_needed_labels_helper(True)


    def test_nonmetahost_over_metahost(self):
        """
        Non-metahost entries should take priority over metahost entries
        for the same host
        """
        self._create_job(metahosts=[1])
        self._create_job(hosts=[1])
        self._dispatcher._schedule_new_jobs()
        self._assert_job_scheduled_on(2, 1)
        self._check_for_extra_schedulings()


    def test_metahosts_obey_blocks(self):
        """
        Metahosts can't get scheduled on hosts already scheduled for
        that job.
        """
        self._create_job(metahosts=[1], hosts=[1])
        # make the nonmetahost entry complete, so the metahost can try
        # to get scheduled
        self._update_hqe(set='complete = 1', where='host_id=1')
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()


    def test_only_schedule_queued_entries(self):
        'Entries that are already active are not scheduled again'
        self._create_job(metahosts=[1])
        self._update_hqe(set='active=1, host_id=2')
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()
369
370
class DispatcherThrottlingTest(BaseSchedulerTest):
    """
    Check the two limits the dispatcher enforces:
     * the total number of running processes
     * the number of processes started in one cycle
    """
    _MAX_RUNNING = 3
    _MAX_STARTED = 2

    def setUp(self):
        super(DispatcherThrottlingTest, self).setUp()
        dispatcher = self._dispatcher
        dispatcher.max_running_processes = self._MAX_RUNNING
        dispatcher.max_processes_started_per_cycle = self._MAX_STARTED


    class DummyAgent(object):
        """Stand-in agent: tick() marks it running, set_done() finishes it."""
        _is_running = False
        _is_done = False
        num_processes = 1

        def is_running(self):
            return self._is_running


        def tick(self):
            self._is_running = True


        def is_done(self):
            return self._is_done


        def set_done(self, done):
            self._is_running = not done
            self._is_done = done


    def _setup_some_agents(self, num_agents):
        """Create num_agents DummyAgents and give the dispatcher a copy."""
        agents = []
        for _ in xrange(num_agents):
            agents.append(self.DummyAgent())
        self._agents = agents
        self._dispatcher._agents = list(agents)


    def _run_a_few_cycles(self):
        """Call the dispatcher's _handle_agents() four times."""
        for _ in xrange(4):
            self._dispatcher._handle_agents()


    def _assert_agents_started(self, indexes, is_started=True):
        """Assert each agent at `indexes` has running state == is_started."""
        # the failure message describes what went wrong, i.e. the opposite
        # of the expected state
        if is_started:
            negation = 'not '
        else:
            negation = ''
        for index in indexes:
            running = self._agents[index].is_running()
            self.assert_(running == is_started,
                         'Agent %d %sstarted' % (index, negation))


    def _assert_agents_not_started(self, indexes):
        self._assert_agents_started(indexes, is_started=False)


    def test_throttle_total(self):
        'With four agents only _MAX_RUNNING of them get started'
        self._setup_some_agents(4)
        self._run_a_few_cycles()
        self._assert_agents_started([0, 1, 2])
        self._assert_agents_not_started([3])


    def test_throttle_per_cycle(self):
        'A single cycle starts no more than _MAX_STARTED agents'
        self._setup_some_agents(3)
        self._dispatcher._handle_agents()
        self._assert_agents_started([0, 1])
        self._assert_agents_not_started([2])


    def test_throttle_with_synchronous(self):
        'An agent needing three processes leaves no room for another'
        self._setup_some_agents(2)
        big_agent = self._agents[0]
        big_agent.num_processes = 3
        self._run_a_few_cycles()
        self._assert_agents_started([0])
        self._assert_agents_not_started([1])


    def test_large_agent_starvation(self):
        """
        Ensure large agents don't get starved by lower-priority agents.
        """
        self._setup_some_agents(3)
        first_agent, large_agent = self._agents[0], self._agents[1]
        large_agent.num_processes = 3
        self._run_a_few_cycles()
        self._assert_agents_started([0])
        self._assert_agents_not_started([1, 2])

        first_agent.set_done(True)
        self._run_a_few_cycles()
        self._assert_agents_started([1])
        self._assert_agents_not_started([2])


    def test_zero_process_agent(self):
        'An agent with num_processes 0 is started despite the limit'
        self._setup_some_agents(5)
        free_agent = self._agents[4]
        free_agent.num_processes = 0
        self._run_a_few_cycles()
        self._assert_agents_started([0, 1, 2, 4])
        self._assert_agents_not_started([3])
473
474
class AbortTest(BaseSchedulerTest):
    """
    Covers the dispatcher's abort handling together with AbortTask.
    """
    def setUp(self):
        super(AbortTest, self).setUp()
        for class_name in ('RebootTask', 'VerifyTask', 'AbortTask',
                           'HostQueueEntry', 'Agent'):
            self.god.stub_class(monitor_db, class_name)


    def _setup_queue_entries(self, host_id, hqe_id):
        """
        Build a Host with a stubbed set_status and expect a HostQueueEntry
        to be constructed from the row with id hqe_id.

        @return: (host, mocked queue entry)
        """
        host = monitor_db.Host(id=host_id)
        self.god.stub_function(host, 'set_status')
        queue_entry = monitor_db.HostQueueEntry.expect_new(row=IsRow(hqe_id))
        queue_entry.id = hqe_id
        return host, queue_entry


    def _setup_abort_expects(self, host, hqe, abort_agent=None):
        """
        Record the calls expected while one queue entry is aborted.

        With abort_agent an AbortTask is expected to be created for it;
        without one the entry and host are expected to have their statuses
        set directly.

        @return: the mocked Agent expected to be created for the entry.
        """
        hqe.get_host.expect_call().and_return(host)
        reboot_task = monitor_db.RebootTask.expect_new(host)
        verify_task = monitor_db.VerifyTask.expect_new(host=host)
        if not abort_agent:
            hqe.set_status.expect_call('Aborted')
            host.set_status.expect_call('Rebooting')
            leading_tasks = []
        else:
            monitor_db.AbortTask.expect_new(hqe, [abort_agent])
            leading_tasks = [mock.is_instance_comparator(monitor_db.AbortTask)]
        expected_tasks = leading_tasks + [reboot_task, verify_task]
        new_agent = monitor_db.Agent.expect_new(tasks=expected_tasks,
                                                queue_entry_ids=[hqe.id])
        new_agent.queue_entry_ids = [hqe.id]
        return new_agent


    def test_find_aborting_inactive(self):
        'Aborting entries that are not active'
        self._create_job(hosts=[1, 2])
        self._update_hqe(set='status="Abort"')

        first_host, first_hqe = self._setup_queue_entries(1, 1)
        second_host, second_hqe = self._setup_queue_entries(2, 2)
        first_agent = self._setup_abort_expects(first_host, first_hqe)
        second_agent = self._setup_abort_expects(second_host, second_hqe)

        self._dispatcher._find_aborting()

        self.assertEquals(self._dispatcher._agents,
                          [first_agent, second_agent])
        self.god.check_playback()


    def test_find_aborting_active(self):
        'Aborting entries that are active and owned by an existing agent'
        self._create_job(hosts=[1, 2])
        self._update_hqe(set='status="Abort", active=1')
        # have to make an Agent for the active HQEs
        # NOTE(review): the QueueTask mock below is never used afterwards;
        # the call is kept so the sequence of mock creation is unchanged
        self.god.create_mock_class(monitor_db.QueueTask, 'QueueTask')
        old_agent = self.god.create_mock_class(monitor_db.Agent, 'OldAgent')
        old_agent.queue_entry_ids = [1, 2]
        self._dispatcher.add_agent(old_agent)

        first_host, first_hqe = self._setup_queue_entries(1, 1)
        second_host, second_hqe = self._setup_queue_entries(2, 2)
        first_agent = self._setup_abort_expects(first_host, first_hqe,
                                                abort_agent=old_agent)
        second_agent = self._setup_abort_expects(second_host, second_hqe)

        self._dispatcher._find_aborting()

        self.assertEquals(self._dispatcher._agents,
                          [first_agent, second_agent])
        self.god.check_playback()
547
548
class PidfileRunMonitorTest(unittest.TestCase):
    """
    Tests for monitor_db.PidfileRunMonitor with monitor_db.open,
    os.path.exists and the notification e-mail function stubbed out.
    """
    results_dir = '/test/path'
    pidfile_path = os.path.join(results_dir, monitor_db.AUTOSERV_PID_FILE)
    pid = 12345
    args = ('nice -n 10 autoserv -P 123-myuser/myhost -p -n '
            '-r ' + results_dir + ' -b -u myuser -l my-job-name '
            '-m myhost /tmp/filejx43Zi -c')
    # same command line, but pointing at a different results directory
    bad_args = args.replace(results_dir, '/random/results/dir')

    def setUp(self):
        self.god = mock.mock_god()
        self.god.stub_function(monitor_db, 'open')
        self.god.stub_function(os.path, 'exists')
        self.god.stub_function(monitor_db.email_manager,
                               'enqueue_notify_email')
        self.monitor = monitor_db.PidfileRunMonitor(self.results_dir)


    def tearDown(self):
        self.god.unstub_all()


    def _proc_cmdline_path(self):
        """Path of the /proc cmdline file for self.pid."""
        return '/proc/%d/cmdline' % self.pid


    def _expect_proc_cmdline_unreadable(self):
        """Expect the /proc cmdline file to be opened and raise IOError."""
        monitor_db.open.expect_call(
            self._proc_cmdline_path(), 'r').and_raises(IOError)


    def set_not_yet_run(self):
        """Expect an existence check on the pidfile that reports False."""
        os.path.exists.expect_call(self.pidfile_path).and_return(False)


    def setup_pidfile(self, pidfile_contents):
        """Expect the pidfile to exist and to be opened with these contents."""
        os.path.exists.expect_call(self.pidfile_path).and_return(True)
        fake_pidfile = StringIO.StringIO(pidfile_contents)
        expected_open = monitor_db.open.expect_call(self.pidfile_path, 'r')
        expected_open.and_return(fake_pidfile)


    def set_running(self):
        """Pidfile holding only the pid line."""
        self.setup_pidfile('%s\n' % self.pid)


    def set_complete(self, error_code):
        """Pidfile holding the pid line followed by the exit status line."""
        self.setup_pidfile('%s\n%s\n' % (self.pid, error_code))


    def _test_read_pidfile_helper(self, expected_pid, expected_exit_status):
        result = self.monitor.read_pidfile()
        pid, exit_status = result
        self.assertEquals(pid, expected_pid)
        self.assertEquals(exit_status, expected_exit_status)
        self.god.check_playback()


    def test_read_pidfile(self):
        'read_pidfile for a missing, a running and a completed pidfile'
        self.set_not_yet_run()
        self._test_read_pidfile_helper(None, None)

        self.set_running()
        self._test_read_pidfile_helper(self.pid, None)

        self.set_complete(123)
        self._test_read_pidfile_helper(self.pid, 123)


    def test_read_pidfile_error(self):
        'unparseable pidfile contents raise PidfileException'
        self.setup_pidfile('asdf')
        self.assertRaises(monitor_db.PidfileException,
                          self.monitor.read_pidfile)
        self.god.check_playback()


    def setup_proc_cmdline(self, args):
        """Expect the /proc cmdline file to be opened, NUL-separated."""
        nul_separated = args.replace(' ', '\x00')
        fake_proc_file = StringIO.StringIO(nul_separated)
        expected_open = monitor_db.open.expect_call(
            self._proc_cmdline_path(), 'r')
        expected_open.and_return(fake_proc_file)


    def setup_find_autoservs(self, process_dict):
        """Stub Dispatcher.find_autoservs and make it return process_dict."""
        self.god.stub_class_method(monitor_db.Dispatcher,
                                   'find_autoservs')
        expected_call = monitor_db.Dispatcher.find_autoservs.expect_call()
        expected_call.and_return(process_dict)


    def _test_get_pidfile_info_helper(self, expected_pid,
                                      expected_exit_status):
        result = self.monitor.get_pidfile_info()
        pid, exit_status = result
        self.assertEquals(pid, expected_pid)
        self.assertEquals(exit_status, expected_exit_status)
        self.god.check_playback()


    def test_get_pidfile_info(self):
        'normal cases for get_pidfile_info'
        # running
        self.set_running()
        self.setup_proc_cmdline(self.args)
        self._test_get_pidfile_info_helper(self.pid, None)

        # exited during check
        self.set_running()
        self._expect_proc_cmdline_unreadable()
        self.set_complete(123) # pidfile gets read again
        self._test_get_pidfile_info_helper(self.pid, 123)

        # completed
        self.set_complete(123)
        self._test_get_pidfile_info_helper(self.pid, 123)


    def test_get_pidfile_info_running_no_proc(self):
        'pidfile shows process running, but no proc exists'
        # running but no proc
        self.set_running()
        self._expect_proc_cmdline_unreadable()
        self.set_running()
        any_string = mock.is_string_comparator
        monitor_db.email_manager.enqueue_notify_email.expect_call(
            any_string(), any_string())
        self._test_get_pidfile_info_helper(self.pid, 1)
        self.assertTrue(self.monitor.lost_process)


    def test_get_pidfile_info_not_yet_run(self):
        "pidfile hasn't been written yet"
        # process not running
        self.set_not_yet_run()
        self.setup_find_autoservs({})
        self._test_get_pidfile_info_helper(None, None)

        # process running
        self.set_not_yet_run()
        self.setup_find_autoservs({self.pid : self.args})
        self._test_get_pidfile_info_helper(None, None)

        # another process running under same pid
        self.set_not_yet_run()
        self.setup_find_autoservs({self.pid : self.bad_args})
        self._test_get_pidfile_info_helper(None, None)
686
687
class AgentTest(unittest.TestCase):
    """Drives a monitor_db.Agent through a queue of mocked tasks."""

    def setUp(self):
        self.god = mock.mock_god()


    def tearDown(self):
        self.god.unstub_all()


    def test_agent(self):
        """
        task1 needs one poll and succeeds, task2 fails straight away and
        names task3 as its failure task, task3 succeeds straight away.
        """
        make_task = self.god.create_mock_class
        task1 = make_task(monitor_db.AgentTask, 'task1')
        task2 = make_task(monitor_db.AgentTask, 'task2')
        task3 = make_task(monitor_db.AgentTask, 'task3')

        def expect_done_after_start(task):
            # started, then reported as done on both is_done() calls
            task.start.expect_call()
            task.is_done.expect_call().and_return(True)
            task.is_done.expect_call().and_return(True)

        # task1: not done at first, polled once, then done
        task1.start.expect_call()
        task1.is_done.expect_call().and_return(False)
        task1.poll.expect_call()
        task1.is_done.expect_call().and_return(True)
        task1.is_done.expect_call().and_return(True)
        task1.success = True

        expect_done_after_start(task2)
        task2.success = False
        task2.failure_tasks = [task3]

        expect_done_after_start(task3)
        task3.success = True

        agent = monitor_db.Agent([task1, task2])
        agent.dispatcher = object()
        agent.start()
        while True:
            if agent.is_done():
                break
            agent.tick()
        self.god.check_playback()
728        self.god.check_playback()
729
730
class AgentTasksTest(unittest.TestCase):
    """
    Tests for RepairTask, VerifyTask and VerifySynchronousTask, run against
    a mocked host, a mocked queue entry and a stubbed RunMonitor.
    """
    TEMP_DIR = '/temp/dir'
    HOSTNAME = 'myhost'
    HOST_PROTECTION = host_protections.default

    def setUp(self):
        self.god = mock.mock_god()
        # replace tempfile.mkdtemp with a mock function built around
        # TEMP_DIR; the tests below expect TEMP_DIR in the autoserv command
        self.god.stub_with(tempfile, 'mkdtemp',
                           mock.mock_function('mkdtemp', self.TEMP_DIR))
        self.god.stub_class_method(monitor_db.RunMonitor, 'run')
        self.god.stub_class_method(monitor_db.RunMonitor, 'exit_code')
        self.host = self.god.create_mock_class(monitor_db.Host, 'host')
        self.host.hostname = self.HOSTNAME
        self.host.protection = self.HOST_PROTECTION
        self.queue_entry = self.god.create_mock_class(
            monitor_db.HostQueueEntry, 'queue_entry')
        self.queue_entry.host = self.host
        self.queue_entry.meta_host = None


    def tearDown(self):
        self.god.unstub_all()


    def run_task(self, task, success):
        """
        Do essentially what an Agent would do, but protect against
        infinite looping from test errors.

        @param task: the task to start and poll until it is done.
        @param success: expected value of task.success once it is done.
        """
        if not getattr(task, 'agent', None):
            task.agent = object()
        task.start()
        count = 0
        while not task.is_done():
            count += 1
            if count > 10:
                # parenthesized so the line parses on Python 2 and 3 alike
                print('Task failed to finish')
                # in case the playback has clues to why it
                # failed
                self.god.check_playback()
                self.fail('Task failed to finish')
            task.poll()
        self.assertEquals(task.success, success)


    def setup_run_monitor(self, exit_status):
        """
        Expect RunMonitor.run() and two exit_code() calls; only the second
        exit_code() call is given a return value (exit_status).
        """
        monitor_db.RunMonitor.run.expect_call()
        monitor_db.RunMonitor.exit_code.expect_call()
        monitor_db.RunMonitor.exit_code.expect_call().and_return(
            exit_status)


    def _test_repair_task_helper(self, success):
        """Run a RepairTask on the host, expecting `success` as the outcome."""
        self.host.set_status.expect_call('Repairing')
        if success:
            self.setup_run_monitor(0)
            self.host.set_status.expect_call('Ready')
        else:
            self.setup_run_monitor(1)
            self.host.set_status.expect_call('Repair Failed')

        task = monitor_db.RepairTask(self.host)
        self.assertEquals(task.failure_tasks, [])
        self.run_task(task, success)

        # derive the expectation from the protection the mocked host
        # actually carries
        expected_protection = host_protections.Protection.get_string(
            self.HOST_PROTECTION)
        expected_protection = host_protections.Protection.get_attr_name(
            expected_protection)

        # only require these arguments to be present; the command may
        # contain more
        self.assertTrue(set(task.monitor.cmd) >=
                        set(['autoserv', '-R', '-m', self.HOSTNAME, '-r',
                             self.TEMP_DIR, '--host-protection',
                             expected_protection]))
        self.god.check_playback()


    def test_repair_task(self):
        self._test_repair_task_helper(True)
        self._test_repair_task_helper(False)


    def test_repair_task_with_queue_entry(self):
        'A failed repair is reported to the queue entry passed to the task'
        queue_entry = self.god.create_mock_class(
            monitor_db.HostQueueEntry, 'queue_entry')
        self.host.set_status.expect_call('Repairing')
        self.setup_run_monitor(1)
        self.host.set_status.expect_call('Repair Failed')
        queue_entry.handle_host_failure.expect_call()

        task = monitor_db.RepairTask(self.host, queue_entry)
        self.run_task(task, False)
        self.god.check_playback()


    def setup_verify_expects(self, success, use_queue_entry):
        """
        Record the calls expected while a VerifyTask runs.

        @param success: whether the stubbed run reports exit status 0.
        @param use_queue_entry: whether the task is given self.queue_entry
                (adds the queue entry calls) rather than just the host.
        """
        if use_queue_entry:
            self.queue_entry.set_status.expect_call('Verifying')
            self.queue_entry.verify_results_dir.expect_call(
                ).and_return('/verify/results/dir')
            self.queue_entry.clear_results_dir.expect_call(
                '/verify/results/dir')
        self.host.set_status.expect_call('Verifying')
        if success:
            self.setup_run_monitor(0)
            self.host.set_status.expect_call('Ready')
        else:
            self.setup_run_monitor(1)
            if use_queue_entry:
                self.queue_entry.requeue.expect_call()


    def _check_verify_failure_tasks(self, verify_task):
        """
        Check the task's only failure task is a RepairTask for the same
        host, which gets the queue entry only for non-metahost entries.
        """
        self.assertEquals(len(verify_task.failure_tasks), 1)
        repair_task = verify_task.failure_tasks[0]
        self.assert_(isinstance(repair_task, monitor_db.RepairTask))
        self.assertEquals(verify_task.host, repair_task.host)
        if verify_task.queue_entry and not verify_task.queue_entry.meta_host:
            self.assertEquals(repair_task.fail_queue_entry,
                              verify_task.queue_entry)
        else:
            self.assertEquals(repair_task.fail_queue_entry, None)


    def _test_verify_task_helper(self, success, use_queue_entry=False,
                                 use_meta_host=False):
        """
        Run a VerifyTask and check its failure tasks and command line.

        @param success: expected outcome of the task.
        @param use_queue_entry: build the task from self.queue_entry instead
                of self.host.
        @param use_meta_host: mark self.queue_entry as a metahost entry
                before the task is created.
        """
        if use_meta_host:
            # without this the metahost test would be an exact repeat of
            # the plain queue entry test
            # NOTE(review): assumes VerifyTask makes the same mock calls for
            # a metahost entry as for a regular one -- confirm
            self.queue_entry.meta_host = 1
        self.setup_verify_expects(success, use_queue_entry)

        if use_queue_entry:
            task = monitor_db.VerifyTask(
                queue_entry=self.queue_entry)
        else:
            task = monitor_db.VerifyTask(host=self.host)
        self._check_verify_failure_tasks(task)
        self.run_task(task, success)
        self.assertTrue(set(task.monitor.cmd) >=
                        set(['autoserv', '-v', '-m', self.HOSTNAME, '-r',
                        self.TEMP_DIR]))
        self.god.check_playback()


    def test_verify_task_with_host(self):
        self._test_verify_task_helper(True)
        self._test_verify_task_helper(False)


    def test_verify_task_with_queue_entry(self):
        self._test_verify_task_helper(True, use_queue_entry=True)
        self._test_verify_task_helper(False, use_queue_entry=True)


    def test_verify_task_with_metahost(self):
        self._test_verify_task_helper(True, use_queue_entry=True,
                                      use_meta_host=True)
        self._test_verify_task_helper(False, use_queue_entry=True,
                                      use_meta_host=True)


    def test_verify_synchronous_task(self):
        'A successful synchronous verify calls on_pending on the entry'
        job = self.god.create_mock_class(monitor_db.Job, 'job')

        self.setup_verify_expects(True, True)
        job.num_complete.expect_call().and_return(0)
        self.queue_entry.on_pending.expect_call()
        self.queue_entry.job = job

        task = monitor_db.VerifySynchronousTask(self.queue_entry)
        task.agent = Dummy()
        task.agent.dispatcher = Dummy()
        self.god.stub_with(task.agent.dispatcher, 'add_agent',
                           mock.mock_function('add_agent'))
        self.run_task(task, True)
        self.god.check_playback()
903        self.god.check_playback()
904
905
class JobTest(BaseSchedulerTest):
    """Checks which tasks monitor_db.Job.run() queues for a queue entry."""

    def _test_run_helper(self, expect_agent=True):
        """
        Run job 1 with queue entry 1.

        @param expect_agent: when false, assert run() returned None.
        @return: the tasks queued on the returned Agent, or None when no
                agent is expected.
        """
        job = monitor_db.Job.fetch('id = 1').next()
        queue_entry = monitor_db.HostQueueEntry.fetch('id = 1').next()
        agent = job.run(queue_entry)

        if expect_agent:
            self.assert_(isinstance(agent, monitor_db.Agent))
            return list(agent.queue.queue)

        self.assertEquals(agent, None)


    def test_run_asynchronous(self):
        'asynchronous job: a verify task followed by a queue task'
        self._create_job(hosts=[1, 2])

        tasks = self._test_run_helper()

        self.assertEquals(len(tasks), 2)
        verify_task = tasks[0]
        queue_task = tasks[1]

        self.assert_(isinstance(verify_task, monitor_db.VerifyTask))
        self.assertEquals(verify_task.queue_entry.id, 1)

        self.assert_(isinstance(queue_task, monitor_db.QueueTask))
        self.assertEquals(queue_task.job.id, 1)


    def test_run_asynchronous_skip_verify(self):
        'asynchronous job with run_verify off: only a queue task'
        job = self._create_job(hosts=[1, 2])
        job.run_verify = False
        job.save()

        tasks = self._test_run_helper()

        self.assertEquals(len(tasks), 1)
        (queue_task,) = tasks

        self.assert_(isinstance(queue_task, monitor_db.QueueTask))
        self.assertEquals(queue_task.job.id, 1)


    def test_run_synchronous_verify(self):
        'synchronous job: a single VerifySynchronousTask'
        self._create_job(hosts=[1, 2], synchronous=True)

        tasks = self._test_run_helper()
        self.assertEquals(len(tasks), 1)
        (verify_task,) = tasks

        self.assert_(isinstance(verify_task, monitor_db.VerifySynchronousTask))
        self.assertEquals(verify_task.queue_entry.id, 1)


    def test_run_synchronous_skip_verify(self):
        'synchronous job with run_verify off: no agent, entry goes Pending'
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.run_verify = False
        job.save()

        self._test_run_helper(expect_agent=False)

        queue_entry = models.HostQueueEntry.smart_get(1)
        self.assertEquals(queue_entry.status, 'Pending')


    def test_run_synchronous_ready(self):
        'synchronous job with all entries Pending: one task for both entries'
        self._create_job(hosts=[1, 2], synchronous=True)
        self._update_hqe("status='Pending'")

        tasks = self._test_run_helper()
        self.assertEquals(len(tasks), 1)
        (queue_task,) = tasks

        self.assert_(isinstance(queue_task, monitor_db.QueueTask))
        self.assertEquals(queue_task.job.id, 1)
        hqe_ids = []
        for entry in queue_task.queue_entries:
            hqe_ids.append(entry.id)
        self.assertEquals(hqe_ids, [1, 2])
984
985
# Run all test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
988