1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for ODE solvers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.contrib.integrate.python.ops import odes
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import errors_impl
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.platform import test
31
32
class OdeIntTest(test.TestCase):
  """Tests for the adaptive ODE solver `odes.odeint`."""

  def setUp(self):
    super(OdeIntTest, self).setUp()
    # Default problem shared by several tests below:
    #   dy / dt = [[0, 1], [-1, 0]] y,  y(0) = [[1], [0]],
    # whose solution is sinusoidal (y1 = cos(t), y2 = -sin(t)).
    matrix = constant_op.constant([[0, 1], [-1, 0]], dtype=dtypes.float64)
    self.func = lambda y, t: math_ops.matmul(matrix, y)
    self.y0 = np.array([[1.0], [0.0]])

  def test_odeint_exp(self):
    # Test odeint by an exponential function:
    #   dy / dt = y,  y(0) = 1.0.
    # Its analytical solution is y = exp(t).
    func = lambda y, t: y
    y0 = constant_op.constant(1.0, dtype=dtypes.float64)
    t = np.linspace(0.0, 1.0, 11)
    y_solved = odes.odeint(func, y0, t)
    # The result op is created under an 'odeint' name scope and its static
    # shape is (len(t),) for a scalar state.
    self.assertIn('odeint', y_solved.name)
    self.assertEqual(y_solved.get_shape(), tensor_shape.TensorShape([11]))
    with self.test_session() as sess:
      y_solved = sess.run(y_solved)
    y_true = np.exp(t)
    self.assertAllClose(y_true, y_solved)

  def test_odeint_complex(self):
    # Test a complex, linear ODE:
    #   dy / dt = k * y,  y(0) = 1.0.
    # Its analytical solution is y = exp(k * t).
    k = 1j - 0.1
    func = lambda y, t: k * y
    t = np.linspace(0.0, 1.0, 11)
    y_solved = odes.odeint(func, 1.0 + 0.0j, t)
    with self.test_session() as sess:
      y_solved = sess.run(y_solved)
    y_true = np.exp(k * t)
    self.assertAllClose(y_true, y_solved)

  def test_odeint_riccati(self):
    # The Riccati equation is:
    #   dy / dt = (y - t) ** 2 + 1.0,  y(0) = 0.5.
    # Its analytical solution is y = 1.0 / (2.0 - t) + t.
    # odeint calls func(y, t), so the parameters are declared in that order.
    func = lambda y, t: (y - t)**2 + 1.0
    t = np.linspace(0.0, 1.0, 11)
    y_solved = odes.odeint(func, np.float64(0.5), t)
    with self.test_session() as sess:
      y_solved = sess.run(y_solved)
    y_true = 1.0 / (2.0 - t) + t
    self.assertAllClose(y_true, y_solved)

  def test_odeint_2d_linear(self):
    # Solve the 2D linear differential equation:
    #   dy1 / dt = 3.0 * y1 + 4.0 * y2,
    #   dy2 / dt = -4.0 * y1 + 3.0 * y2,
    #   y1(0) = 0.0,
    #   y2(0) = 1.0.
    # Its analytical solution is
    #   y1 = sin(4.0 * t) * exp(3.0 * t),
    #   y2 = cos(4.0 * t) * exp(3.0 * t).
    matrix = constant_op.constant(
        [[3.0, 4.0], [-4.0, 3.0]], dtype=dtypes.float64)
    func = lambda y, t: math_ops.matmul(matrix, y)

    y0 = constant_op.constant([[0.0], [1.0]], dtype=dtypes.float64)
    t = np.linspace(0.0, 1.0, 11)

    y_solved = odes.odeint(func, y0, t)
    with self.test_session() as sess:
      y_solved = sess.run(y_solved)

    # State is a (2, 1) column vector, so the stacked output is (len(t), 2, 1).
    y_true = np.zeros((len(t), 2, 1))
    y_true[:, 0, 0] = np.sin(4.0 * t) * np.exp(3.0 * t)
    y_true[:, 1, 0] = np.cos(4.0 * t) * np.exp(3.0 * t)
    self.assertAllClose(y_true, y_solved, atol=1e-5)

  def test_odeint_higher_rank(self):
    # The output shape should be (len(t),) + y0.shape for scalar and
    # higher-rank states, both statically and at run time.
    func = lambda y, t: y
    y0 = constant_op.constant(1.0, dtype=dtypes.float64)
    t = np.linspace(0.0, 1.0, 11)
    for shape in [(), (1,), (1, 1)]:
      expected_shape = (len(t),) + shape
      y_solved = odes.odeint(func, array_ops.reshape(y0, shape), t)
      self.assertEqual(y_solved.get_shape(),
                       tensor_shape.TensorShape(expected_shape))
      with self.test_session() as sess:
        y_value = sess.run(y_solved)
      self.assertEqual(y_value.shape, expected_shape)

  def test_odeint_all_dtypes(self):
    # dy / dt = y, y(0) = 1 must solve to exp(t) for every supported
    # combination of state and time dtypes, and the result must keep y0's
    # dtype.
    func = lambda y, t: y
    t = np.linspace(0.0, 1.0, 11)
    for y0_dtype in [
        dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128
    ]:
      for t_dtype in [dtypes.float32, dtypes.float64]:
        y0 = math_ops.cast(1.0, y0_dtype)
        y_solved = odes.odeint(func, y0, math_ops.cast(t, t_dtype))
        with self.test_session() as sess:
          y_solved = sess.run(y_solved)
        expected = np.asarray(np.exp(t))
        self.assertAllClose(y_solved, expected, rtol=1e-5)
        self.assertEqual(dtypes.as_dtype(y_solved.dtype), y0_dtype)

  def test_odeint_required_dtypes(self):
    # Integer-typed y0 or t are rejected at graph-construction time.
    with self.assertRaisesRegexp(TypeError, '`y0` must have a floating point'):
      odes.odeint(self.func, math_ops.cast(self.y0, dtypes.int32), [0, 1])

    with self.assertRaisesRegexp(TypeError, '`t` must have a floating point'):
      odes.odeint(self.func, self.y0, math_ops.cast([0, 1], dtypes.int32))

  def test_odeint_runtime_errors(self):
    # Passing `options` without an explicit `method` is a construction error.
    with self.assertRaisesRegexp(ValueError, 'cannot supply `options` without'):
      odes.odeint(self.func, self.y0, [0, 1], options={'first_step': 1.0})

    # A zero step budget only fails once the graph is executed.
    y = odes.odeint(
        self.func,
        self.y0, [0, 1],
        method='dopri5',
        options={'max_num_steps': 0})
    with self.test_session() as sess:
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   'max_num_steps'):
        sess.run(y)

    # Decreasing output times are rejected at run time.
    y = odes.odeint(self.func, self.y0, [1, 0])
    with self.test_session() as sess:
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   'monotonic increasing'):
        sess.run(y)

  def test_odeint_different_times(self):
    # integrate steps should be independent of interpolation times
    times0 = np.linspace(0, 10, num=11, dtype=float)
    times1 = np.linspace(0, 10, num=101, dtype=float)

    with self.test_session() as sess:
      y_solved_0, info_0 = sess.run(
          odes.odeint(self.func, self.y0, times0, full_output=True))
      y_solved_1, info_1 = sess.run(
          odes.odeint(self.func, self.y0, times1, full_output=True))

    # times0 is every 10th point of times1, so the solutions must agree
    # there, and the internal step diagnostics must be identical.
    self.assertAllClose(y_solved_0, y_solved_1[::10])
    self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals'])
    self.assertAllEqual(info_0['integrate_points'], info_1['integrate_points'])
    self.assertAllEqual(info_0['error_ratio'], info_1['error_ratio'])

  def test_odeint_5th_order_accuracy(self):
    # Tightening atol by a factor of 1000 is expected to increase the number
    # of integration points by roughly 1000 ** (1 / 5).
    t = [0, 20]
    kwargs = dict(
        full_output=True, method='dopri5', options=dict(max_num_steps=2000))
    with self.test_session() as sess:
      _, info_0 = sess.run(
          odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
      _, info_1 = sess.run(
          odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs))
    self.assertAllClose(
        info_0['integrate_points'].size * 1000**0.2,
        float(info_1['integrate_points'].size),
        rtol=0.01)
191
192
class StepSizeTest(test.TestCase):
  """Tests for the adaptive step-size rule `odes._optimal_step_size`."""

  def _next_step(self, ratio):
    """Runs `_optimal_step_size` with last_step=1.0 and returns the result."""
    step_op = odes._optimal_step_size(
        last_step=constant_op.constant(1.0),
        error_ratio=constant_op.constant(ratio))
    with self.test_session() as sess:
      return sess.run(step_op)

  # NOTE(review): the expected constants below are presumably the controller's
  # safety factor (0.9), max growth factor "ifactor" (10) and max shrink
  # factor "dfactor" (0.2) -- confirm against odes._optimal_step_size.

  def test_error_ratio_one(self):
    self.assertAllClose(self._next_step(1.0), 0.9)

  def test_ifactor(self):
    # Zero error: the step grows by the largest allowed factor.
    self.assertAllClose(self._next_step(0.0), 10.0)

  def test_dfactor(self):
    # Huge error: the step shrinks by the largest allowed factor.
    self.assertAllClose(self._next_step(1e6), 0.2)
218
219
class InterpolationTest(test.TestCase):
  """Tests for the interpolation helpers `_interp_fit` / `_interp_evaluate`."""

  def test_5th_order_polynomial(self):
    # The five constraints passed to _interp_fit (presumably values at both
    # ends and the midpoint, plus slopes at both ends -- confirm in odes)
    # pin down this quartic, so the fitted interpolant should be exact.
    poly = lambda x: x**4 + x**3 - 2 * x**2 + 4 * x + 5
    poly_deriv = lambda x: 4 * x**3 + 3 * x**2 - 4 * x + 4
    t_start, t_end = 0.0, 10.0
    t_mid = 0.5 * (t_start + t_end)
    coeffs = odes._interp_fit(
        poly(t_start), poly(t_end), poly(t_mid),
        poly_deriv(t_start), poly_deriv(t_end), t_end - t_start)

    sample_times = np.linspace(0, 10, dtype=np.float32)
    fitted = array_ops.stack([
        odes._interp_evaluate(coeffs, t_start, t_end, s)
        for s in sample_times
    ])
    reference = poly(sample_times)
    with self.test_session() as sess:
      self.assertAllClose(reference, sess.run(fitted))

    # Evaluating outside [t_start, t_end] must fail when the op is run.
    out_of_range = odes._interp_evaluate(coeffs, t_start, t_end, 100.0)
    with self.test_session() as sess:
      with self.assertRaises(errors_impl.InvalidArgumentError):
        sess.run(out_of_range)
241
242
class OdeIntFixedTest(test.TestCase):
  """Tests for the fixed-grid solver `odes.odeint_fixed`."""

  def _test_integrate_sine(self, method):
    """Solves y'' = -y as a first-order system; y[0] should track sin(t)."""

    def evol_func(y, t):
      del t  # The system is autonomous.
      return array_ops.stack([y[1], -y[0]])

    y0 = [0., 1.]
    time_grid = np.linspace(0., 10., 200)
    y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)

    with self.test_session() as sess:
      y_grid_array = sess.run(y_grid)

    self.assertAllClose(
        y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2)

  def _test_integrate_gaussian(self, method):
    """Solves dy/dt = -t * y, y(0) = 1, whose solution is exp(-t**2 / 2)."""

    def evol_func(y, t):
      # Return a derivative with the same shape as the state (here (1,)) so
      # the solver does not depend on broadcasting a scalar derivative.
      return -math_ops.cast(t, dtype=y.dtype) * y

    y0 = [1.]
    time_grid = np.linspace(0., 2., 100)
    y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)

    with self.test_session() as sess:
      y_grid_array = sess.run(y_grid)

    self.assertAllClose(
        y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2)

  def _test_everything(self, method):
    """Runs every fixed-grid scenario with the given solver method."""
    self._test_integrate_sine(method)
    self._test_integrate_gaussian(method)

  def test_midpoint(self):
    self._test_everything('midpoint')

  def test_rk4(self):
    self._test_everything('rk4')
285
286
if __name__ == "__main__":
  # Run the tests above through TensorFlow's test runner.
  test.main()
289