1#!/usr/bin/env python
2# Copyright 2013 The Chromium Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6import traceback
7import unittest
8
9
10from future import All, Future, Race
11from mock_function import MockFunction
12
13
14class FutureTest(unittest.TestCase):
15  def testNoValueOrDelegate(self):
16    self.assertRaises(ValueError, Future)
17
18  def testValue(self):
19    future = Future(value=42)
20    self.assertEqual(42, future.Get())
21    self.assertEqual(42, future.Get())
22
23  def testDelegateValue(self):
24    called = [False,]
25    def callback():
26      self.assertFalse(called[0])
27      called[0] = True
28      return 42
29    future = Future(callback=callback)
30    self.assertEqual(42, future.Get())
31    self.assertEqual(42, future.Get())
32
33  def testErrorThrowingDelegate(self):
34    class FunkyException(Exception):
35      pass
36
37    # Set up a chain of functions to test the stack trace.
38    def qux():
39      raise FunkyException()
40    def baz():
41      return qux()
42    def bar():
43      return baz()
44    def foo():
45      return bar()
46    chain = [foo, bar, baz, qux]
47
48    called = [False,]
49    def callback():
50      self.assertFalse(called[0])
51      called[0] = True
52      return foo()
53
54    fail = self.fail
55    assertTrue = self.assertTrue
56    def assert_raises_full_stack(future, err):
57      try:
58        future.Get()
59        fail('Did not raise %s' % err)
60      except Exception as e:
61        assertTrue(isinstance(e, err))
62        stack = traceback.format_exc()
63        assertTrue(all(stack.find(fn.__name__) != -1 for fn in chain))
64
65    future = Future(callback=callback)
66    assert_raises_full_stack(future, FunkyException)
67    assert_raises_full_stack(future, FunkyException)
68
69  def testAll(self):
70    def callback_with_value(value):
71      return MockFunction(lambda: value)
72
73    # Test a single value.
74    callback = callback_with_value(42)
75    future = All((Future(callback=callback),))
76    self.assertTrue(*callback.CheckAndReset(0))
77    self.assertEqual([42], future.Get())
78    self.assertTrue(*callback.CheckAndReset(1))
79
80    # Test multiple callbacks.
81    callbacks = (callback_with_value(1),
82                 callback_with_value(2),
83                 callback_with_value(3))
84    future = All(Future(callback=callback) for callback in callbacks)
85    for callback in callbacks:
86      self.assertTrue(*callback.CheckAndReset(0))
87    self.assertEqual([1, 2, 3], future.Get())
88    for callback in callbacks:
89      self.assertTrue(*callback.CheckAndReset(1))
90
91    # Test throwing an error.
92    def throws_error():
93      raise ValueError()
94    callbacks = (callback_with_value(1),
95                 callback_with_value(2),
96                 MockFunction(throws_error))
97
98    future = All(Future(callback=callback) for callback in callbacks)
99    for callback in callbacks:
100      self.assertTrue(*callback.CheckAndReset(0))
101    self.assertRaises(ValueError, future.Get)
102    for callback in callbacks:
103      # Can't check that the callbacks were actually run because in theory the
104      # Futures can be resolved in any order.
105      callback.CheckAndReset(0)
106
107    # Test throwing an error with except_pass.
108    future = All((Future(callback=callback) for callback in callbacks),
109                 except_pass=ValueError)
110    for callback in callbacks:
111      self.assertTrue(*callback.CheckAndReset(0))
112    self.assertEqual([1, 2, None], future.Get())
113
114  def testRaceSuccess(self):
115    callback = MockFunction(lambda: 42)
116
117    # Test a single value.
118    race = Race((Future(callback=callback),))
119    self.assertTrue(*callback.CheckAndReset(0))
120    self.assertEqual(42, race.Get())
121    self.assertTrue(*callback.CheckAndReset(1))
122
123    # Test multiple success values. Note that we could test different values
124    # and check that the first returned, but this is just an implementation
125    # detail of Race. When we have parallel Futures this might not always hold.
126    race = Race((Future(callback=callback),
127                 Future(callback=callback),
128                 Future(callback=callback)))
129    self.assertTrue(*callback.CheckAndReset(0))
130    self.assertEqual(42, race.Get())
131    # Can't assert the actual count here for the same reason as above.
132    callback.CheckAndReset(99)
133
134    # Test values with except_pass.
135    def throws_error():
136      raise ValueError()
137    race = Race((Future(callback=callback),
138                 Future(callback=throws_error)),
139                 except_pass=(ValueError,))
140    self.assertTrue(*callback.CheckAndReset(0))
141    self.assertEqual(42, race.Get())
142    self.assertTrue(*callback.CheckAndReset(1))
143
144  def testRaceErrors(self):
145    def throws_error():
146      raise ValueError()
147
148    # Test a single error.
149    race = Race((Future(callback=throws_error),))
150    self.assertRaises(ValueError, race.Get)
151
152    # Test multiple errors. Can't use different error types for the same reason
153    # as described in testRaceSuccess.
154    race = Race((Future(callback=throws_error),
155                 Future(callback=throws_error),
156                 Future(callback=throws_error)))
157    self.assertRaises(ValueError, race.Get)
158
159    # Test values with except_pass.
160    def throws_except_error():
161      raise NotImplementedError()
162    race = Race((Future(callback=throws_error),
163                 Future(callback=throws_except_error)),
164                 except_pass=(NotImplementedError,))
165    self.assertRaises(ValueError, race.Get)
166
167    race = Race((Future(callback=throws_error),
168                 Future(callback=throws_error)),
169                 except_pass=(ValueError,))
170    self.assertRaises(ValueError, race.Get)
171
172    # Test except_pass with default values.
173    race = Race((Future(callback=throws_error),
174                 Future(callback=throws_except_error)),
175                 except_pass=(NotImplementedError,),
176                 default=42)
177    self.assertRaises(ValueError, race.Get)
178
179    race = Race((Future(callback=throws_error),
180                 Future(callback=throws_error)),
181                 except_pass=(ValueError,),
182                 default=42)
183    self.assertEqual(42, race.Get())
184
185  def testThen(self):
186    def assertIs42(val):
187      self.assertEqual(val, 42)
188      return val
189
190    then = Future(value=42).Then(assertIs42)
191    # Shouldn't raise an error.
192    self.assertEqual(42, then.Get())
193
194    # Test raising an error.
195    then = Future(value=41).Then(assertIs42)
196    self.assertRaises(AssertionError, then.Get)
197
198    # Test setting up an error handler.
199    def handle(error):
200      if isinstance(error, ValueError):
201        return 'Caught'
202      raise error
203
204    def raiseValueError():
205      raise ValueError
206
207    def raiseException():
208      raise Exception
209
210    then = Future(callback=raiseValueError).Then(assertIs42, handle)
211    self.assertEqual('Caught', then.Get())
212    then = Future(callback=raiseException).Then(assertIs42, handle)
213    self.assertRaises(Exception, then.Get)
214
215    # Test chains of thens.
216    addOne = lambda val: val + 1
217    then = Future(value=40).Then(addOne).Then(addOne).Then(assertIs42)
218    # Shouldn't raise an error.
219    self.assertEqual(42, then.Get())
220
221    # Test error in chain.
222    then = Future(value=40).Then(addOne).Then(assertIs42).Then(addOne)
223    self.assertRaises(AssertionError, then.Get)
224
225    # Test handle error in chain.
226    def raiseValueErrorWithVal(val):
227      raise ValueError
228
229    then = Future(value=40).Then(addOne).Then(raiseValueErrorWithVal).Then(
230        addOne, handle).Then(lambda val: val + ' me')
231    self.assertEquals(then.Get(), 'Caught me')
232
233    # Test multiple handlers.
234    def myHandle(error):
235      if isinstance(error, AssertionError):
236        return 10
237      raise error
238
239    then = Future(value=40).Then(assertIs42).Then(addOne, handle).Then(addOne,
240                                                                       myHandle)
241    self.assertEquals(then.Get(), 10)
242
243  def testThenResolvesReturnedFutures(self):
244    def returnsFortyTwo():
245      return Future(value=42)
246    def inc(x):
247      return x + 1
248    def incFuture(x):
249      return Future(value=x + 1)
250
251    self.assertEqual(43, returnsFortyTwo().Then(inc).Get())
252    self.assertEqual(43, returnsFortyTwo().Then(incFuture).Get())
253    self.assertEqual(44, returnsFortyTwo().Then(inc).Then(inc).Get())
254    self.assertEqual(44, returnsFortyTwo().Then(inc).Then(incFuture).Get())
255    self.assertEqual(44, returnsFortyTwo().Then(incFuture).Then(inc).Get())
256    self.assertEqual(
257        44, returnsFortyTwo().Then(incFuture).Then(incFuture).Get())
258
259    # The same behaviour should apply to error handlers.
260    def raisesSomething():
261      def boom(): raise ValueError
262      return Future(callback=boom)
263    def shouldNotHappen(_):
264      raise AssertionError()
265    def oops(error):
266      return 'oops'
267    def oopsFuture(error):
268      return Future(value='oops')
269
270    self.assertEqual(
271        'oops', raisesSomething().Then(shouldNotHappen, oops).Get())
272    self.assertEqual(
273        'oops', raisesSomething().Then(shouldNotHappen, oopsFuture).Get())
274    self.assertEqual(
275        'oops',
276        raisesSomething().Then(shouldNotHappen, raisesSomething)
277                         .Then(shouldNotHappen, oops).Get())
278    self.assertEqual(
279        'oops',
280        raisesSomething().Then(shouldNotHappen, raisesSomething)
281                         .Then(shouldNotHappen, oopsFuture).Get())
282
283
# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()
286