monitor_db_unittest.py revision b2e2c325bc0b1d822690b6af07f920d5da398cb8
1#!/usr/bin/python
2
3import unittest, time, subprocess, os, StringIO, tempfile, datetime
4import MySQLdb
5import common
6from autotest_lib.client.common_lib import global_config, host_protections
7from autotest_lib.client.common_lib.test_utils import mock
8from autotest_lib.database import database_connection, migrate
9from autotest_lib.scheduler import monitor_db
10
11from autotest_lib.frontend import django_test_utils
12django_test_utils.setup_test_environ()
13from autotest_lib.frontend.afe import models
14
15_DEBUG = False
16
class Dummy(object):
    """Empty placeholder object; tests hang arbitrary attributes off it."""
19
20
class IsRow(mock.argument_comparator):
    """
    Mock argument comparator matching a database row by its first field.

    The comparator is satisfied when the first element of the supplied
    parameter equals row_id.  A parameter that is None or empty simply does
    not match, instead of raising from inside the mock framework.
    """
    def __init__(self, row_id):
        self.row_id = row_id


    def is_satisfied_by(self, parameter):
        if parameter is None:
            return False
        # Only the first field is needed, so stop after it instead of
        # copying the whole row into a list.
        for first_field in parameter:
            return first_field == self.row_id
        return False


    def __str__(self):
        return 'row with id %s' % self.row_id
32
33
class BaseSchedulerTest(unittest.TestCase):
    """
    Base fixture for scheduler tests backed by the Django test database.

    setUp() restores a pristine copy of the test database, fills in a small
    fixed data set, points monitor_db at the test connection and creates a
    fresh Dispatcher in self._dispatcher.
    """
    _config_section = 'AUTOTEST_WEB'
    # The database file is created, synced and backed up only once (guarded
    # by this flag); every test then restores it from that backup.
    _test_db_initialized = False

    def _do_query(self, sql):
        """Execute a raw SQL statement on the test database connection."""
        self._database.execute(sql)


    @classmethod
    def _initialize_test_db(cls):
        """
        Create a temporary database file, run syncdb on it and keep a backup
        of the result.  Calls after the first one are no-ops.
        """
        if cls._test_db_initialized:
            return
        temp_fd, cls._test_db_file = tempfile.mkstemp(suffix='.monitor_test')
        os.close(temp_fd)
        django_test_utils.set_test_database(cls._test_db_file)
        django_test_utils.run_syncdb()
        cls._test_db_backup = django_test_utils.backup_test_database()
        cls._test_db_initialized = True


    def _open_test_db(self):
        """Restore the pristine test database and connect to it."""
        self._initialize_test_db()
        django_test_utils.restore_test_database(self._test_db_backup)
        self._database = (
            database_connection.DatabaseConnection.get_test_database(
                self._test_db_file))
        self._database.connect()
        self._database.debug = _DEBUG


    def _close_test_db(self):
        self._database.disconnect()


    def _set_monitor_stubs(self):
        """Make monitor_db use the test database connection."""
        # NOTE(review): this module attribute is not restored in tearDown.
        monitor_db._db = self._database


    def _fill_in_test_data(self):
        """
        Populate the fixture data shared by the scheduler tests:
          * user 'my_user', a member of ACL group 'my_acl'
          * hosts host1..host4, all in 'my_acl'
          * labels label1..label3, with label3 marked only_if_needed
          * host1 labelled label1, host2 labelled label2
        """
        user = models.User.objects.create(login='my_user')
        acl_group = models.AclGroup.objects.create(name='my_acl')
        acl_group.users.add(user)

        hosts = [models.Host.objects.create(hostname=hostname) for hostname in
                 ('host1', 'host2', 'host3', 'host4')]
        acl_group.hosts = hosts

        labels = [models.Label.objects.create(name=name) for name in
                  ('label1', 'label2', 'label3')]
        labels[2].only_if_needed = True
        labels[2].save()
        hosts[0].labels.add(labels[0])
        hosts[1].labels.add(labels[1])


    def setUp(self):
        self.god = mock.mock_god()
        self._open_test_db()
        self._fill_in_test_data()
        self._set_monitor_stubs()
        self._dispatcher = monitor_db.Dispatcher()


    def tearDown(self):
        # Undo the stubs even if disconnecting fails, so a broken teardown
        # cannot leak stubbed functions into the tests that run afterwards.
        try:
            self._close_test_db()
        finally:
            self.god.unstub_all()


    def _create_job(self, hosts=(), metahosts=(), priority=0, active=0,
                    synchronous=False):
        """
        Create a job named 'test' owned by 'my_user'.

        @param hosts: host ids; each gets a HostQueueEntry and an
                IneligibleHostQueue row for the job.
        @param metahosts: label ids; each gets a metahost HostQueueEntry.
        @param priority: priority of the job and of all its entries.
        @param active: value stored in each entry's active field.
        @param synchronous: use synch_type 2 instead of 1.
        @returns the created models.Job.
        """
        # immutable defaults: the sequences are only iterated, never changed
        if synchronous:
            synch_type = 2
        else:
            synch_type = 1
        created_on = datetime.datetime(2008, 1, 1)
        job = models.Job.objects.create(name='test', owner='my_user',
                                        priority=priority,
                                        synch_type=synch_type,
                                        created_on=created_on)
        for host_id in hosts:
            models.HostQueueEntry.objects.create(job=job, priority=priority,
                                                 host_id=host_id, active=active)
            models.IneligibleHostQueue.objects.create(job=job, host_id=host_id)
        for label_id in metahosts:
            models.HostQueueEntry.objects.create(job=job, priority=priority,
                                                 meta_host_id=label_id,
                                                 active=active)
        return job


    def _create_job_simple(self, hosts, use_metahost=False,
                          priority=0, active=0):
        """
        An alternative interface to _create_job: the ids in hosts are used
        as label ids for metahost entries when use_metahost is set, and as
        host ids otherwise.
        """
        args = {'hosts' : [], 'metahosts' : []}
        if use_metahost:
            args['metahosts'] = hosts
        else:
            args['hosts'] = hosts
        return self._create_job(priority=priority, active=active, **args)


    def _update_hqe(self, set, where=''):
        """
        Run 'UPDATE host_queue_entries SET <set> [WHERE <where>]'.

        Both fragments are raw SQL supplied by the test.  (The parameter
        name shadows the builtin set but is kept for existing callers.)
        """
        query = 'UPDATE host_queue_entries SET ' + set
        if where:
            query += ' WHERE ' + where
        self._do_query(query)
137
138
class DispatcherSchedulingTest(BaseSchedulerTest):
    """
    Tests for Dispatcher._schedule_new_jobs().

    HostQueueEntry.run is replaced by a stub that records every scheduling
    decision as a (job_id, host_id) pair, which the tests then assert on.
    """
    # Replaced with a fresh list for every test in setUp().
    _jobs_scheduled = []

    def _set_monitor_stubs(self):
        super(DispatcherSchedulingTest, self)._set_monitor_stubs()
        def run_stub(hqe_self, assigned_host=None):
            hqe_self.set_status('Starting')
            # a metahost entry runs on the host the scheduler picked for it,
            # a normal entry on its own host
            if hqe_self.meta_host:
                host = assigned_host
            else:
                host = hqe_self.host
            self._record_job_scheduled(hqe_self.job.id, host.id)
            return Dummy()
        # Install the stub through the mock god rather than assigning the
        # class attribute directly, so tearDown's unstub_all() restores the
        # real HostQueueEntry.run and the stub cannot leak into other tests.
        self.god.stub_with(monitor_db.HostQueueEntry, 'run', run_stub)


    def _record_job_scheduled(self, job_id, host_id):
        """Record that job_id was started on host_id; fail on duplicates."""
        record = (job_id, host_id)
        self.assert_(record not in self._jobs_scheduled,
                     'Job %d scheduled on host %d twice' %
                     (job_id, host_id))
        self._jobs_scheduled.append(record)


    def _assert_job_scheduled_on(self, job_id, host_id):
        """Assert job_id was scheduled on host_id and consume that record."""
        record = (job_id, host_id)
        self.assert_(record in self._jobs_scheduled,
                     'Job %d not scheduled on host %d as expected\n'
                     'Jobs scheduled: %s' %
                     (job_id, host_id, self._jobs_scheduled))
        self._jobs_scheduled.remove(record)


    def _check_for_extra_schedulings(self):
        """Fail if any recorded scheduling was not consumed by an assert."""
        if self._jobs_scheduled:
            self.fail('Extra jobs scheduled: ' +
                      str(self._jobs_scheduled))


    def _convert_jobs_to_metahosts(self, *job_ids):
        """Move each entry's host_id into meta_host for the given jobs."""
        sql_tuple = '(' + ','.join(str(i) for i in job_ids) + ')'
        self._do_query('UPDATE host_queue_entries SET '
                       'meta_host=host_id, host_id=NULL '
                       'WHERE job_id IN ' + sql_tuple)


    def _lock_host(self, host_id):
        self._do_query('UPDATE hosts SET locked=1 WHERE id=' +
                       str(host_id))


    def setUp(self):
        super(DispatcherSchedulingTest, self).setUp()
        self._jobs_scheduled = []


    def _test_basic_scheduling_helper(self, use_metahosts):
        'Basic scheduling: one job per host (or per label with metahosts)'
        self._create_job_simple([1], use_metahosts)
        self._create_job_simple([2], use_metahosts)
        self._dispatcher._schedule_new_jobs()
        self._assert_job_scheduled_on(1, 1)
        self._assert_job_scheduled_on(2, 2)
        self._check_for_extra_schedulings()


    def _test_priorities_helper(self, use_metahosts):
        'Test prioritization ordering'
        self._create_job_simple([1], use_metahosts)
        self._create_job_simple([2], use_metahosts)
        self._create_job_simple([1,2], use_metahosts)
        self._create_job_simple([1], use_metahosts, priority=1)
        self._dispatcher._schedule_new_jobs()
        self._assert_job_scheduled_on(4, 1) # higher priority
        self._assert_job_scheduled_on(2, 2) # earlier job over later
        self._check_for_extra_schedulings()


    def _test_hosts_ready_helper(self, use_metahosts):
        """
        Only hosts that are status=Ready, unlocked and not invalid get
        scheduled.
        """
        self._create_job_simple([1], use_metahosts)
        self._do_query('UPDATE hosts SET status="Running" WHERE id=1')
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()

        self._do_query('UPDATE hosts SET status="Ready", locked=1 '
                       'WHERE id=1')
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()

        self._do_query('UPDATE hosts SET locked=0, invalid=1 '
                       'WHERE id=1')
        self._dispatcher._schedule_new_jobs()
        # invalid hosts still accept non-metahost entries
        if not use_metahosts:
            self._assert_job_scheduled_on(1, 1)
        self._check_for_extra_schedulings()


    def _test_hosts_idle_helper(self, use_metahosts):
        'Only idle hosts get scheduled'
        self._create_job(hosts=[1], active=1)
        self._create_job_simple([1], use_metahosts)
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()


    def _test_obey_ACLs_helper(self, use_metahosts):
        'Hosts outside the owner\'s ACL groups are never scheduled'
        self._do_query('DELETE FROM acl_groups_hosts WHERE host_id=1')
        self._create_job_simple([1], use_metahosts)
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()


    def _test_only_if_needed_labels_helper(self, use_metahosts):
        # apply only_if_needed label3 to host1
        label3 = models.Label.smart_get('label3')
        models.Host.smart_get('host1').labels.add(label3)

        job = self._create_job_simple([1], use_metahosts)
        # if the job doesn't depend on label3, there should be no scheduling
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()

        # now make the job depend on label3
        job.dependency_labels.add(label3)
        self._dispatcher._schedule_new_jobs()
        self._assert_job_scheduled_on(1, 1)
        self._check_for_extra_schedulings()

        if use_metahosts:
            # should also work if the metahost is the only_if_needed label
            self._do_query('DELETE FROM jobs_dependency_labels')
            self._create_job(metahosts=[3])
            self._dispatcher._schedule_new_jobs()
            self._assert_job_scheduled_on(2, 1)
            self._check_for_extra_schedulings()


    def test_basic_scheduling(self):
        self._test_basic_scheduling_helper(False)


    def test_priorities(self):
        self._test_priorities_helper(False)


    def test_hosts_ready(self):
        self._test_hosts_ready_helper(False)


    def test_hosts_idle(self):
        self._test_hosts_idle_helper(False)


    def test_obey_ACLs(self):
        self._test_obey_ACLs_helper(False)


    def test_only_if_needed_labels(self):
        self._test_only_if_needed_labels_helper(False)


    def test_non_metahost_on_invalid_host(self):
        """
        Non-metahost entries can get scheduled on invalid hosts (this is how
        one-time hosts work).
        """
        self._do_query('UPDATE hosts SET invalid=1')
        self._test_basic_scheduling_helper(False)


    def test_metahost_scheduling(self):
        """
        Basic metahost scheduling
        """
        self._test_basic_scheduling_helper(True)


    def test_metahost_priorities(self):
        self._test_priorities_helper(True)


    def test_metahost_hosts_ready(self):
        self._test_hosts_ready_helper(True)


    def test_metahost_hosts_idle(self):
        self._test_hosts_idle_helper(True)


    def test_metahost_obey_ACLs(self):
        self._test_obey_ACLs_helper(True)


    def test_metahost_only_if_needed_labels(self):
        self._test_only_if_needed_labels_helper(True)


    def test_nonmetahost_over_metahost(self):
        """
        Non-metahost entries should take priority over metahost entries
        for the same host
        """
        self._create_job(metahosts=[1])
        self._create_job(hosts=[1])
        self._dispatcher._schedule_new_jobs()
        self._assert_job_scheduled_on(2, 1)
        self._check_for_extra_schedulings()


    def test_metahosts_obey_blocks(self):
        """
        Metahosts can't get scheduled on hosts already scheduled for
        that job.
        """
        self._create_job(metahosts=[1], hosts=[1])
        # make the nonmetahost entry complete, so the metahost can try
        # to get scheduled
        self._update_hqe(set='complete = 1', where='host_id=1')
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()


    def test_only_schedule_queued_entries(self):
        # an entry that is already active on host 2 must not be scheduled
        self._create_job(metahosts=[1])
        self._update_hqe(set='active=1, host_id=2')
        self._dispatcher._schedule_new_jobs()
        self._check_for_extra_schedulings()
370
371
class DispatcherThrottlingTest(BaseSchedulerTest):
    """
    Check that Dispatcher._handle_agents() limits both
     * the total number of running processes, and
     * the number of processes started in a single cycle.
    """
    _MAX_RUNNING = 3
    _MAX_STARTED = 2

    def setUp(self):
        super(DispatcherThrottlingTest, self).setUp()
        dispatcher = self._dispatcher
        dispatcher.max_running_processes = self._MAX_RUNNING
        dispatcher.max_processes_started_per_cycle = self._MAX_STARTED


    class DummyAgent(object):
        """
        Agent stand-in that counts as running after its first tick and is
        done only once set_done(True) has been called.
        """
        _is_running = False
        _is_done = False
        num_processes = 1

        def is_running(self):
            return self._is_running


        def tick(self):
            self._is_running = True


        def is_done(self):
            return self._is_done


        def set_done(self, done):
            self._is_done = done
            self._is_running = not done


    def _setup_some_agents(self, num_agents):
        """Hand the dispatcher num_agents fresh DummyAgents, in order."""
        self._agents = [self.DummyAgent() for _ in range(num_agents)]
        # the dispatcher works on its own copy of the list
        self._dispatcher._agents = self._agents[:]


    def _run_a_few_cycles(self):
        """Let the dispatcher handle its agents four times."""
        remaining = 4
        while remaining:
            self._dispatcher._handle_agents()
            remaining -= 1


    def _assert_agents_started(self, indexes, is_started=True):
        """Assert each agent in indexes is (or, if not is_started, isn't) running."""
        if is_started:
            qualifier = 'not '
        else:
            qualifier = ''
        for index in indexes:
            self.assert_(self._agents[index].is_running() == is_started,
                         'Agent %d %sstarted' % (index, qualifier))


    def _assert_agents_not_started(self, indexes):
        self._assert_agents_started(indexes, is_started=False)


    def test_throttle_total(self):
        """No more than _MAX_RUNNING agents run at once."""
        self._setup_some_agents(4)
        self._run_a_few_cycles()
        self._assert_agents_started([0, 1, 2])
        self._assert_agents_not_started([3])


    def test_throttle_per_cycle(self):
        """No more than _MAX_STARTED agents start in one cycle."""
        self._setup_some_agents(3)
        self._dispatcher._handle_agents()
        self._assert_agents_started([0, 1])
        self._assert_agents_not_started([2])


    def test_throttle_with_synchronous(self):
        """An agent using all _MAX_RUNNING process slots blocks the rest."""
        self._setup_some_agents(2)
        self._agents[0].num_processes = 3
        self._run_a_few_cycles()
        self._assert_agents_started([0])
        self._assert_agents_not_started([1])


    def test_large_agent_starvation(self):
        """
        Ensure large agents don't get starved by lower-priority agents.
        """
        self._setup_some_agents(3)
        self._agents[1].num_processes = 3
        self._run_a_few_cycles()
        self._assert_agents_started([0])
        self._assert_agents_not_started([1, 2])

        # once agent 0 finishes, the large agent 1 gets to run first
        self._agents[0].set_done(True)
        self._run_a_few_cycles()
        self._assert_agents_started([1])
        self._assert_agents_not_started([2])


    def test_zero_process_agent(self):
        """An agent needing no processes starts even when slots are full."""
        self._setup_some_agents(5)
        self._agents[4].num_processes = 0
        self._run_a_few_cycles()
        self._assert_agents_started([0, 1, 2, 4])
        self._assert_agents_not_started([3])
474
475
class AbortTest(BaseSchedulerTest):
    """
    Test both the dispatcher abort functionality and AbortTask.
    """
    def setUp(self):
        super(AbortTest, self).setUp()
        self.god.stub_class(monitor_db, 'RebootTask')
        self.god.stub_class(monitor_db, 'VerifyTask')
        self.god.stub_class(monitor_db, 'AbortTask')
        self.god.stub_class(monitor_db, 'HostQueueEntry')
        self.god.stub_class(monitor_db, 'Agent')


    def _setup_queue_entries(self, host_id, hqe_id):
        """
        Build a monitor_db.Host (with set_status stubbed) and expect a
        HostQueueEntry to be constructed from the row whose id is hqe_id.

        @returns (host, hqe) where hqe is the expected mock entry.
        """
        host = monitor_db.Host(id=host_id)
        self.god.stub_function(host, 'set_status')
        hqe = monitor_db.HostQueueEntry.expect_new(row=IsRow(hqe_id))
        hqe.id = hqe_id
        return host, hqe


    def _setup_abort_expects(self, host, hqe, abort_agent=None):
        """
        Record the calls expected while aborting hqe on host and return the
        mock Agent the dispatcher is expected to create for it.

        With abort_agent, an AbortTask for (hqe, [abort_agent]) is expected
        and heads the new Agent's task list; without it, hqe is expected to
        be set to 'Aborted' and host to 'Rebooting' directly.  In both cases
        the Agent also gets the expected RebootTask and VerifyTask.
        """
        hqe.get_host.expect_call().and_return(host)
        reboot_task = monitor_db.RebootTask.expect_new(host)
        verify_task = monitor_db.VerifyTask.expect_new(host=host)
        if abort_agent:
            # only the expectation matters; the Agent check below matches
            # the AbortTask by type
            monitor_db.AbortTask.expect_new(hqe, [abort_agent])
            tasks = [mock.is_instance_comparator(monitor_db.AbortTask)]
        else:
            hqe.set_status.expect_call('Aborted')
            host.set_status.expect_call('Rebooting')
            tasks = []
        tasks += [reboot_task, verify_task]
        agent = monitor_db.Agent.expect_new(tasks=tasks,
                                            queue_entry_ids=[hqe.id])
        agent.queue_entry_ids = [hqe.id]
        return agent


    def test_find_aborting_inactive(self):
        self._create_job(hosts=[1, 2])
        self._update_hqe(set='status="Abort"')

        host1, hqe1 = self._setup_queue_entries(1, 1)
        host2, hqe2 = self._setup_queue_entries(2, 2)
        agent1 = self._setup_abort_expects(host1, hqe1)
        agent2 = self._setup_abort_expects(host2, hqe2)

        self._dispatcher._find_aborting()

        self.assertEquals(self._dispatcher._agents, [agent1, agent2])
        self.god.check_playback()


    def test_find_aborting_active(self):
        self._create_job(hosts=[1, 2])
        self._update_hqe(set='status="Abort", active=1')
        # have to make an Agent for the active HQEs
        agent = self.god.create_mock_class(monitor_db.Agent, 'OldAgent')
        agent.queue_entry_ids = [1, 2]
        self._dispatcher.add_agent(agent)

        host1, hqe1 = self._setup_queue_entries(1, 1)
        host2, hqe2 = self._setup_queue_entries(2, 2)
        agent1 = self._setup_abort_expects(host1, hqe1, abort_agent=agent)
        agent2 = self._setup_abort_expects(host2, hqe2)

        self._dispatcher._find_aborting()

        self.assertEquals(self._dispatcher._agents, [agent1, agent2])
        self.god.check_playback()
548
549
class PidfileRunMonitorTest(unittest.TestCase):
    """
    Tests for PidfileRunMonitor.

    open(), os.path.exists and email_manager.enqueue_notify_email are
    stubbed, so pidfile and /proc contents come from StringIO objects.
    """
    results_dir = '/test/path'
    pidfile_path = os.path.join(results_dir, monitor_db.AUTOSERV_PID_FILE)
    pid = 12345
    args = ('nice -n 10 autoserv -P 123-myuser/myhost -p -n '
            '-r ' + results_dir + ' -b -u myuser -l my-job-name '
            '-m myhost /tmp/filejx43Zi -c')
    # the same command line, but writing to a different results directory
    bad_args = args.replace(results_dir, '/random/results/dir')

    def setUp(self):
        god = mock.mock_god()
        self.god = god
        god.stub_function(monitor_db, 'open')
        god.stub_function(os.path, 'exists')
        god.stub_function(monitor_db.email_manager, 'enqueue_notify_email')
        self.monitor = monitor_db.PidfileRunMonitor(self.results_dir)


    def tearDown(self):
        self.god.unstub_all()


    def set_not_yet_run(self):
        """Expect a check for the pidfile, which does not exist."""
        os.path.exists.expect_call(self.pidfile_path).and_return(False)


    def setup_pidfile(self, pidfile_contents):
        """Expect the pidfile to exist and to read as pidfile_contents."""
        os.path.exists.expect_call(self.pidfile_path).and_return(True)
        fake_pidfile = StringIO.StringIO(pidfile_contents)
        monitor_db.open.expect_call(self.pidfile_path,
                                    'r').and_return(fake_pidfile)


    def set_running(self):
        """Pidfile containing only the pid line."""
        self.setup_pidfile('%s\n' % self.pid)


    def set_complete(self, error_code):
        """Pidfile containing the pid line followed by an exit status."""
        self.setup_pidfile('%s\n%s\n' % (self.pid, error_code))


    def _check_monitor_result(self, reader, expected_pid,
                              expected_exit_status):
        """Call reader(), compare its (pid, exit status), check playback."""
        found_pid, found_status = reader()
        self.assertEquals(found_pid, expected_pid)
        self.assertEquals(found_status, expected_exit_status)
        self.god.check_playback()


    def _test_read_pidfile_helper(self, expected_pid, expected_exit_status):
        self._check_monitor_result(self.monitor.read_pidfile,
                                   expected_pid, expected_exit_status)


    def test_read_pidfile(self):
        # no pidfile yet
        self.set_not_yet_run()
        self._test_read_pidfile_helper(None, None)

        # pid written, no exit status yet
        self.set_running()
        self._test_read_pidfile_helper(self.pid, None)

        # pid and exit status both written
        self.set_complete(123)
        self._test_read_pidfile_helper(self.pid, 123)


    def test_read_pidfile_error(self):
        """An unparseable pidfile raises PidfileException."""
        self.setup_pidfile('asdf')
        self.assertRaises(monitor_db.PidfileException,
                          self.monitor.read_pidfile)
        self.god.check_playback()


    def _proc_cmdline_path(self):
        return '/proc/%d/cmdline' % self.pid


    def setup_proc_cmdline(self, args):
        """Expect /proc/<pid>/cmdline to be read, yielding args NUL-separated."""
        cmdline_file = StringIO.StringIO(args.replace(' ', '\x00'))
        monitor_db.open.expect_call(self._proc_cmdline_path(),
                                    'r').and_return(cmdline_file)


    def _setup_proc_cmdline_missing(self):
        """Expect opening /proc/<pid>/cmdline to raise IOError."""
        monitor_db.open.expect_call(self._proc_cmdline_path(),
                                    'r').and_raises(IOError)


    def setup_find_autoservs(self, process_dict):
        """Make Dispatcher.find_autoservs() return process_dict once."""
        self.god.stub_class_method(monitor_db.Dispatcher,
                                   'find_autoservs')
        monitor_db.Dispatcher.find_autoservs.expect_call().and_return(
            process_dict)


    def _test_get_pidfile_info_helper(self, expected_pid,
                                      expected_exit_status):
        self._check_monitor_result(self.monitor.get_pidfile_info,
                                   expected_pid, expected_exit_status)


    def test_get_pidfile_info(self):
        """Normal cases for get_pidfile_info."""
        # running
        self.set_running()
        self.setup_proc_cmdline(self.args)
        self._test_get_pidfile_info_helper(self.pid, None)

        # exited during check: the /proc entry is gone and the pidfile,
        # read again, now has an exit status
        self.set_running()
        self._setup_proc_cmdline_missing()
        self.set_complete(123)
        self._test_get_pidfile_info_helper(self.pid, 123)

        # completed
        self.set_complete(123)
        self._test_get_pidfile_info_helper(self.pid, 123)


    def test_get_pidfile_info_running_no_proc(self):
        """Pidfile shows the process running, but no /proc entry exists."""
        self.set_running()
        self._setup_proc_cmdline_missing()
        self.set_running()
        monitor_db.email_manager.enqueue_notify_email.expect_call(
            mock.is_string_comparator(), mock.is_string_comparator())
        self._test_get_pidfile_info_helper(self.pid, 1)
        self.assertTrue(self.monitor.lost_process)


    def test_get_pidfile_info_not_yet_run(self):
        """The pidfile hasn't been written yet."""
        # no autoserv found, ours found, and an unrelated autoserv with
        # the same pid: none of them changes the result
        for autoservs in ({},
                          {self.pid : self.args},
                          {self.pid : self.bad_args}):
            self.set_not_yet_run()
            self.setup_find_autoservs(autoservs)
            self._test_get_pidfile_info_helper(None, None)
687
688
class AgentTest(unittest.TestCase):
    """Tests for Agent driving a sequence of (mocked) AgentTasks."""
    def setUp(self):
        self.god = mock.mock_god()


    def tearDown(self):
        self.god.unstub_all()


    def _expect_immediate_finish(self, task):
        """Expect task to be started and then reported done on two checks."""
        task.start.expect_call()
        task.is_done.expect_call().and_return(True)
        task.is_done.expect_call().and_return(True)


    def test_agent(self):
        """
        Tasks are started and polled in order; a failing task's
        failure_tasks are run after it.
        """
        task1, task2, task3 = [
            self.god.create_mock_class(monitor_db.AgentTask, task_name)
            for task_name in ('task1', 'task2', 'task3')]

        # task1 needs one poll before it is done, and succeeds
        task1.start.expect_call()
        task1.is_done.expect_call().and_return(False)
        task1.poll.expect_call()
        task1.is_done.expect_call().and_return(True)
        task1.is_done.expect_call().and_return(True)
        task1.success = True

        # task2 is done immediately but fails, which brings in task3
        self._expect_immediate_finish(task2)
        task2.success = False
        task2.failure_tasks = [task3]

        # task3 is done immediately and succeeds
        self._expect_immediate_finish(task3)
        task3.success = True

        agent = monitor_db.Agent([task1, task2])
        agent.dispatcher = object()
        agent.start()
        while not agent.is_done():
            agent.tick()
        self.god.check_playback()
730
731
class AgentTasksTest(unittest.TestCase):
    """
    Tests for RepairTask, VerifyTask and VerifySynchronousTask.

    RunMonitor.run/exit_code are stubbed, tempfile.mkdtemp is stubbed to
    return TEMP_DIR, and the host and queue entry are mock objects.
    """
    TEMP_DIR = '/temp/dir'
    HOSTNAME = 'myhost'
    HOST_PROTECTION = host_protections.default

    def setUp(self):
        self.god = mock.mock_god()
        self.god.stub_with(tempfile, 'mkdtemp',
                           mock.mock_function('mkdtemp', self.TEMP_DIR))
        self.god.stub_class_method(monitor_db.RunMonitor, 'run')
        self.god.stub_class_method(monitor_db.RunMonitor, 'exit_code')
        self.host = self.god.create_mock_class(monitor_db.Host, 'host')
        self.host.hostname = self.HOSTNAME
        self.host.protection = self.HOST_PROTECTION
        self.queue_entry = self.god.create_mock_class(
            monitor_db.HostQueueEntry, 'queue_entry')
        self.queue_entry.host = self.host
        self.queue_entry.meta_host = None


    def tearDown(self):
        self.god.unstub_all()


    def run_task(self, task, success):
        """
        Do essentially what an Agent would do, but protect against
        infinite looping from test errors.

        The task is started and polled until done (the test fails after
        10 polls), then its success flag is compared with success.
        """
        if not getattr(task, 'agent', None):
            task.agent = object()
        task.start()
        count = 0
        while not task.is_done():
            count += 1
            if count > 10:
                # check playback first: a playback error usually explains
                # why the task never finished
                self.god.check_playback()
                self.fail('Task failed to finish')
            task.poll()
        self.assertEquals(task.success, success)


    def setup_run_monitor(self, exit_status):
        """
        Expect RunMonitor.run() and then two exit_code() calls, the second
        returning exit_status.
        """
        monitor_db.RunMonitor.run.expect_call()
        monitor_db.RunMonitor.exit_code.expect_call()
        monitor_db.RunMonitor.exit_code.expect_call().and_return(
            exit_status)


    def _test_repair_task_helper(self, success):
        """Run a RepairTask on self.host; check statuses and autoserv cmd."""
        self.host.set_status.expect_call('Repairing')
        if success:
            self.setup_run_monitor(0)
            self.host.set_status.expect_call('Ready')
        else:
            self.setup_run_monitor(1)
            self.host.set_status.expect_call('Repair Failed')

        task = monitor_db.RepairTask(self.host)
        self.assertEquals(task.failure_tasks, [])
        self.run_task(task, success)

        # expected --host-protection value, derived from the protection
        # the mocked host was actually given in setUp
        expected_protection = host_protections.Protection.get_string(
            self.HOST_PROTECTION)
        expected_protection = host_protections.Protection.get_attr_name(
            expected_protection)

        self.assertTrue(set(task.monitor.cmd) >=
                        set(['autoserv', '-R', '-m', self.HOSTNAME, '-r',
                             self.TEMP_DIR, '--host-protection',
                             expected_protection]))
        self.god.check_playback()


    def test_repair_task(self):
        self._test_repair_task_helper(True)
        self._test_repair_task_helper(False)


    def test_repair_task_with_queue_entry(self):
        queue_entry = self.god.create_mock_class(
            monitor_db.HostQueueEntry, 'queue_entry')
        self.host.set_status.expect_call('Repairing')
        self.setup_run_monitor(1)
        self.host.set_status.expect_call('Repair Failed')
        queue_entry.handle_host_failure.expect_call()

        task = monitor_db.RepairTask(self.host, queue_entry)
        self.run_task(task, False)
        self.god.check_playback()


    def setup_verify_expects(self, success, use_queue_entry):
        """
        Record the calls expected from a VerifyTask run.

        @param success: whether the mocked autoserv exits with 0.
        @param use_queue_entry: also expect the queue entry's status and
                results-dir handling, and a requeue on failure.
        """
        if use_queue_entry:
            self.queue_entry.set_status.expect_call('Verifying')
            self.queue_entry.verify_results_dir.expect_call(
                ).and_return('/verify/results/dir')
            self.queue_entry.clear_results_dir.expect_call(
                '/verify/results/dir')
        self.host.set_status.expect_call('Verifying')
        if success:
            self.setup_run_monitor(0)
            self.host.set_status.expect_call('Ready')
        else:
            self.setup_run_monitor(1)
            if use_queue_entry:
                self.queue_entry.requeue.expect_call()


    def _check_verify_failure_tasks(self, verify_task):
        """
        A VerifyTask's single failure task is a RepairTask on the same host,
        given the queue entry to fail only for non-metahost entries.
        """
        self.assertEquals(len(verify_task.failure_tasks), 1)
        repair_task = verify_task.failure_tasks[0]
        self.assert_(isinstance(repair_task, monitor_db.RepairTask))
        self.assertEquals(verify_task.host, repair_task.host)
        if verify_task.queue_entry and not verify_task.queue_entry.meta_host:
            self.assertEquals(repair_task.fail_queue_entry,
                              verify_task.queue_entry)
        else:
            self.assertEquals(repair_task.fail_queue_entry, None)


    def _test_verify_task_helper(self, success, use_queue_entry=False,
                                 use_meta_host=False):
        """
        Run a VerifyTask and check its expectations, failure tasks and cmd.

        @param success: whether the mocked autoserv run succeeds.
        @param use_queue_entry: build the task from self.queue_entry rather
                than from self.host.
        @param use_meta_host: mark self.queue_entry as a metahost entry
                (only meaningful together with use_queue_entry).
        """
        if use_meta_host:
            # without this the "metahost" variant merely repeated the plain
            # queue-entry test, since meta_host stayed None
            self.queue_entry.meta_host = True
        self.setup_verify_expects(success, use_queue_entry)

        if use_queue_entry:
            task = monitor_db.VerifyTask(
                queue_entry=self.queue_entry)
        else:
            task = monitor_db.VerifyTask(host=self.host)
        self._check_verify_failure_tasks(task)
        self.run_task(task, success)
        self.assertTrue(set(task.monitor.cmd) >=
                        set(['autoserv', '-v', '-m', self.HOSTNAME, '-r',
                        self.TEMP_DIR]))
        self.god.check_playback()


    def test_verify_task_with_host(self):
        self._test_verify_task_helper(True)
        self._test_verify_task_helper(False)


    def test_verify_task_with_queue_entry(self):
        self._test_verify_task_helper(True, use_queue_entry=True)
        self._test_verify_task_helper(False, use_queue_entry=True)


    def test_verify_task_with_metahost(self):
        self._test_verify_task_helper(True, use_queue_entry=True,
                                      use_meta_host=True)
        self._test_verify_task_helper(False, use_queue_entry=True,
                                      use_meta_host=True)


    def test_verify_synchronous_task(self):
        job = self.god.create_mock_class(monitor_db.Job, 'job')

        self.setup_verify_expects(True, True)
        job.num_complete.expect_call().and_return(0)
        self.queue_entry.on_pending.expect_call()
        self.queue_entry.job = job

        task = monitor_db.VerifySynchronousTask(self.queue_entry)
        task.agent = Dummy()
        task.agent.dispatcher = Dummy()
        self.god.stub_with(task.agent.dispatcher, 'add_agent',
                           mock.mock_function('add_agent'))
        self.run_task(task, True)
        self.god.check_playback()
905
906
class JobTest(BaseSchedulerTest):
    """Tests for Job.run() and the tasks queued on the Agent it returns."""
    def _test_run_helper(self):
        """
        Run job 1 for queue entry 1 (both fetched from the test database)
        and return the list of tasks queued on the resulting Agent.
        """
        first_job = monitor_db.Job.fetch('id = 1').next()
        first_entry = monitor_db.HostQueueEntry.fetch('id = 1').next()
        run_agent = first_job.run(first_entry)

        self.assert_(isinstance(run_agent, monitor_db.Agent))
        return list(run_agent.queue.queue)


    def test_run_asynchronous(self):
        """An asynchronous job gets a VerifyTask followed by a QueueTask."""
        self._create_job(hosts=[1, 2])

        queued = self._test_run_helper()

        self.assertEquals(len(queued), 2)
        verify_task, queue_task = queued

        self.assert_(isinstance(verify_task, monitor_db.VerifyTask))
        self.assertEquals(verify_task.queue_entry.id, 1)

        self.assert_(isinstance(queue_task, monitor_db.QueueTask))
        self.assertEquals(queue_task.job.id, 1)


    def test_run_synchronous_verify(self):
        """A synchronous job first gets only a VerifySynchronousTask."""
        self._create_job(hosts=[1, 2], synchronous=True)

        queued = self._test_run_helper()
        self.assertEquals(len(queued), 1)
        verify_task = queued[0]

        self.assert_(isinstance(verify_task, monitor_db.VerifySynchronousTask))
        self.assertEquals(verify_task.queue_entry.id, 1)


    def test_run_synchronous_ready(self):
        """With all entries Pending, one QueueTask covers every entry."""
        self._create_job(hosts=[1, 2], synchronous=True)
        self._update_hqe("status='Pending'")

        queued = self._test_run_helper()
        self.assertEquals(len(queued), 1)
        queue_task = queued[0]

        self.assert_(isinstance(queue_task, monitor_db.QueueTask))
        self.assertEquals(queue_task.job.id, 1)
        entry_ids = [entry.id for entry in queue_task.queue_entries]
        self.assertEquals(entry_ids, [1, 2])
956
957
if __name__ == '__main__':
    # executed as a script: run every TestCase defined in this module
    unittest.main()
960