1# Copyright 2014 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Unit tests for the contents of parallelizer.py."""
6
7# pylint: disable=W0212
8# pylint: disable=W0613
9
10import os
11import tempfile
12import time
13import unittest
14
15from devil.utils import parallelizer
16
17
class ParallelizerTestObject(object):
  """Class used to test parallelizer.Parallelizer.

  Each instance wraps a single value ("thing") and exposes small methods that
  return it, raise it, replace it, or sleep before returning it, so tests can
  observe how a Parallelizer fans calls out across a list of instances.
  """

  # Allows ParallelizerTestObject.parallel(devices) in the tests.
  parallel = parallelizer.Parallelizer

  def __init__(self, thing, completion_file_name=None):
    """Initializes the test object.

    Args:
      thing: Arbitrary value used (returned, raised, indexed) by the do*
        methods and __getitem__.
      completion_file_name: Optional path that doRaiseIfExceptionElseSleepFor
        writes 'complete' to when it finishes without raising.
    """
    self._thing = thing
    self._completion_file_name = completion_file_name
    # Note: built from the initial |thing|; doSetTheThing does not update it.
    self.helper = ParallelizerTestObjectHelper(thing)

  @staticmethod
  def doReturn(what):
    """Returns |what| unchanged."""
    return what

  @classmethod
  def doRaise(cls, what):
    """Raises |what|, which must be an exception instance or class."""
    raise what

  def doSetTheThing(self, new_thing):
    """Replaces the wrapped value with |new_thing|."""
    self._thing = new_thing

  def doReturnTheThing(self):
    """Returns the wrapped value."""
    return self._thing

  def doRaiseTheThing(self):
    """Raises the wrapped value; it must be an exception."""
    raise self._thing

  def doRaiseIfExceptionElseSleepFor(self, sleep_duration):
    """Raises the wrapped value if it is an Exception, otherwise completes.

    Args:
      sleep_duration: Seconds to sleep before writing the completion file.
    Returns:
      The wrapped value, when it is not an Exception.
    Raises:
      The wrapped value, when it is an Exception (without sleeping or
      writing the completion file).
    """
    if isinstance(self._thing, Exception):
      raise self._thing
    time.sleep(sleep_duration)
    self._write_completion_file()
    return self._thing

  def _write_completion_file(self):
    """Writes 'complete' to the completion file, if one was given."""
    # A falsy name (None or '') means no completion file was requested.
    if self._completion_file_name:
      # Text mode: the payload is a str and the test reads it back in text
      # mode. Binary mode ('w+b') rejects str payloads on Python 3.
      with open(self._completion_file_name, 'w') as completion_file:
        completion_file.write('complete')

  def __getitem__(self, index):
    """Returns self._thing[index]."""
    return self._thing[index]

  def __str__(self):
    """Returns the class name."""
    return type(self).__name__
62
63
class ParallelizerTestObjectHelper(object):
  """Holds a value and reports its string form.

  ParallelizerTestObject exposes an instance as its |helper| attribute, so
  tests can call through a nested attribute of a Parallelizer.
  """

  def __init__(self, thing):
    """Stores |thing| for later string conversion.

    Args:
      thing: Any value; doReturnStringThing reports str(thing).
    """
    self._thing = thing

  def doReturnStringThing(self):
    """Returns the wrapped value converted with str()."""
    value = self._thing
    return str(value)
71
72
class ParallelizerTest(unittest.TestCase):
  """Tests for parallelizer.Parallelizer."""

  def testInitWithNone(self):
    """Constructing a Parallelizer from None raises AssertionError."""
    with self.assertRaises(AssertionError):
      parallelizer.Parallelizer(None)

  def testInitEmptyList(self):
    """Constructing a Parallelizer from an empty list raises AssertionError."""
    with self.assertRaises(AssertionError):
      parallelizer.Parallelizer([])

  def testMethodCall(self):
    """A method call is applied to every wrapped object; pGet collects them."""
    test_data = ['abc_foo', 'def_foo', 'ghi_foo']
    expected = ['abc_bar', 'def_bar', 'ghi_bar']
    r = parallelizer.Parallelizer(test_data).replace('_foo', '_bar').pGet(0.1)
    self.assertEqual(expected, r)

  def testMutate(self):
    """Calls made through the Parallelizer mutate the original objects."""
    devices = [ParallelizerTestObject(True) for _ in range(10)]
    self.assertTrue(all(d.doReturnTheThing() for d in devices))
    ParallelizerTestObject.parallel(devices).doSetTheThing(False).pFinish(1)
    self.assertFalse(any(d.doReturnTheThing() for d in devices))

  def testAllReturn(self):
    """pGet returns a list with one result per wrapped object."""
    devices = [ParallelizerTestObject(True) for _ in range(10)]
    results = ParallelizerTestObject.parallel(
        devices).doReturnTheThing().pGet(1)
    self.assertIsInstance(results, list)
    self.assertEqual(10, len(results))
    self.assertTrue(all(results))

  def testAllRaise(self):
    """pGet raises when every wrapped call raises."""
    devices = [ParallelizerTestObject(Exception('thing %d' % i))
               for i in range(10)]
    p = ParallelizerTestObject.parallel(devices).doRaiseTheThing()
    with self.assertRaises(Exception):
      p.pGet(1)

  def testOneFailOthersComplete(self):
    """One raising device fails pGet, but the other devices still finish.

    Every non-failing device writes 'complete' to its own temp file; the
    failing device's file must stay empty.
    """
    parallel_device_count = 10
    exception_index = 7
    exception_msg = 'thing %d' % exception_index

    # Bound before the try so the finally clause can always iterate it: a
    # failure while creating the files must not turn into a NameError that
    # masks the real error, and files already created must still be removed.
    completion_files = []
    try:
      for _ in range(parallel_device_count):
        completion_files.append(tempfile.NamedTemporaryFile(delete=False))
      devices = [
          ParallelizerTestObject(
              i if i != exception_index else Exception(exception_msg),
              completion_file.name)
          for i, completion_file in enumerate(completion_files)]
      # Release our handles so each device can reopen its file by name.
      for completion_file in completion_files:
        completion_file.close()
      p = ParallelizerTestObject.parallel(devices)
      with self.assertRaises(Exception) as ctx:
        p.doRaiseIfExceptionElseSleepFor(2).pGet(3)
      self.assertIn(exception_msg, str(ctx.exception))
      for i, completion_file in enumerate(completion_files):
        expected = '' if i == exception_index else 'complete'
        with open(completion_file.name) as f:
          self.assertEqual(expected, f.read())
    finally:
      for completion_file in completion_files:
        # close() is a no-op when already closed; closing before removal
        # matters on platforms that refuse to delete open files.
        completion_file.close()
        os.remove(completion_file.name)

  def testReusable(self):
    """A Parallelizer can be used for several successive calls."""
    devices = [ParallelizerTestObject(True) for _ in range(10)]
    p = ParallelizerTestObject.parallel(devices)
    results = p.doReturn(True).pGet(1)
    self.assertTrue(all(results))
    results = p.doReturn(True).pGet(1)
    self.assertTrue(all(results))
    with self.assertRaises(Exception):
      p.doRaise(Exception('reusableTest')).pGet(1)

  def testContained(self):
    """Calls reach through attributes of the wrapped objects (|helper|)."""
    devices = [ParallelizerTestObject(i) for i in range(10)]
    results = (ParallelizerTestObject.parallel(devices).helper
               .doReturnStringThing().pGet(1))
    self.assertIsInstance(results, list)
    self.assertEqual([str(i) for i in range(10)], results)

  def testGetItem(self):
    """Indexing the Parallelizer indexes each wrapped object."""
    devices = [ParallelizerTestObject(range(i, i + 10)) for i in range(10)]
    results = ParallelizerTestObject.parallel(devices)[9].pGet(1)
    # list() so the comparison also holds where range() is lazy (Python 3).
    self.assertEqual(list(range(9, 19)), results)
162
163
if __name__ == '__main__':
  # Run this module's tests directly, reporting each test by name.
  unittest.main(module='__main__', verbosity=2)
166
167