1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for ODE solvers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.contrib.integrate.python.ops import odes 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import errors_impl 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.platform import test 31 32 33class OdeIntTest(test.TestCase): 34 35 def setUp(self): 36 super(OdeIntTest, self).setUp() 37 # simple defaults (solution is a sin-wave) 38 matrix = constant_op.constant([[0, 1], [-1, 0]], dtype=dtypes.float64) 39 self.func = lambda y, t: math_ops.matmul(matrix, y) 40 self.y0 = np.array([[1.0], [0.0]]) 41 42 def test_odeint_exp(self): 43 # Test odeint by an exponential function: 44 # dy / dt = y, y(0) = 1.0. 45 # Its analytical solution is y = exp(t). 46 func = lambda y, t: y 47 y0 = constant_op.constant(1.0, dtype=dtypes.float64) 48 t = np.linspace(0.0, 1.0, 11) 49 y_solved = odes.odeint(func, y0, t) 50 self.assertIn('odeint', y_solved.name) 51 self.assertEqual(y_solved.get_shape(), tensor_shape.TensorShape([11])) 52 with self.test_session() as sess: 53 y_solved = sess.run(y_solved) 54 y_true = np.exp(t) 55 self.assertAllClose(y_true, y_solved) 56 57 def test_odeint_complex(self): 58 # Test a complex, linear ODE: 59 # dy / dt = k * y, y(0) = 1.0. 60 # Its analytical solution is y = exp(k * t). 61 k = 1j - 0.1 62 func = lambda y, t: k * y 63 t = np.linspace(0.0, 1.0, 11) 64 y_solved = odes.odeint(func, 1.0 + 0.0j, t) 65 with self.test_session() as sess: 66 y_solved = sess.run(y_solved) 67 y_true = np.exp(k * t) 68 self.assertAllClose(y_true, y_solved) 69 70 def test_odeint_riccati(self): 71 # The Ricatti equation is: 72 # dy / dt = (y - t) ** 2 + 1.0, y(0) = 0.5. 73 # Its analytical solution is y = 1.0 / (2.0 - t) + t. 74 func = lambda t, y: (y - t)**2 + 1.0 75 t = np.linspace(0.0, 1.0, 11) 76 y_solved = odes.odeint(func, np.float64(0.5), t) 77 with self.test_session() as sess: 78 y_solved = sess.run(y_solved) 79 y_true = 1.0 / (2.0 - t) + t 80 self.assertAllClose(y_true, y_solved) 81 82 def test_odeint_2d_linear(self): 83 # Solve the 2D linear differential equation: 84 # dy1 / dt = 3.0 * y1 + 4.0 * y2, 85 # dy2 / dt = -4.0 * y1 + 3.0 * y2, 86 # y1(0) = 0.0, 87 # y2(0) = 1.0. 88 # Its analytical solution is 89 # y1 = sin(4.0 * t) * exp(3.0 * t), 90 # y2 = cos(4.0 * t) * exp(3.0 * t). 91 matrix = constant_op.constant( 92 [[3.0, 4.0], [-4.0, 3.0]], dtype=dtypes.float64) 93 func = lambda y, t: math_ops.matmul(matrix, y) 94 95 y0 = constant_op.constant([[0.0], [1.0]], dtype=dtypes.float64) 96 t = np.linspace(0.0, 1.0, 11) 97 98 y_solved = odes.odeint(func, y0, t) 99 with self.test_session() as sess: 100 y_solved = sess.run(y_solved) 101 102 y_true = np.zeros((len(t), 2, 1)) 103 y_true[:, 0, 0] = np.sin(4.0 * t) * np.exp(3.0 * t) 104 y_true[:, 1, 0] = np.cos(4.0 * t) * np.exp(3.0 * t) 105 self.assertAllClose(y_true, y_solved, atol=1e-5) 106 107 def test_odeint_higher_rank(self): 108 func = lambda y, t: y 109 y0 = constant_op.constant(1.0, dtype=dtypes.float64) 110 t = np.linspace(0.0, 1.0, 11) 111 for shape in [(), (1,), (1, 1)]: 112 expected_shape = (len(t),) + shape 113 y_solved = odes.odeint(func, array_ops.reshape(y0, shape), t) 114 self.assertEqual(y_solved.get_shape(), 115 tensor_shape.TensorShape(expected_shape)) 116 with self.test_session() as sess: 117 y_solved = sess.run(y_solved) 118 self.assertEquals(y_solved.shape, expected_shape) 119 120 def test_odeint_all_dtypes(self): 121 func = lambda y, t: y 122 t = np.linspace(0.0, 1.0, 11) 123 for y0_dtype in [ 124 dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128 125 ]: 126 for t_dtype in [dtypes.float32, dtypes.float64]: 127 y0 = math_ops.cast(1.0, y0_dtype) 128 y_solved = odes.odeint(func, y0, math_ops.cast(t, t_dtype)) 129 with self.test_session() as sess: 130 y_solved = sess.run(y_solved) 131 expected = np.asarray(np.exp(t)) 132 self.assertAllClose(y_solved, expected, rtol=1e-5) 133 self.assertEqual(dtypes.as_dtype(y_solved.dtype), y0_dtype) 134 135 def test_odeint_required_dtypes(self): 136 with self.assertRaisesRegexp(TypeError, '`y0` must have a floating point'): 137 odes.odeint(self.func, math_ops.cast(self.y0, dtypes.int32), [0, 1]) 138 139 with self.assertRaisesRegexp(TypeError, '`t` must have a floating point'): 140 odes.odeint(self.func, self.y0, math_ops.cast([0, 1], dtypes.int32)) 141 142 def test_odeint_runtime_errors(self): 143 with self.assertRaisesRegexp(ValueError, 'cannot supply `options` without'): 144 odes.odeint(self.func, self.y0, [0, 1], options={'first_step': 1.0}) 145 146 y = odes.odeint( 147 self.func, 148 self.y0, [0, 1], 149 method='dopri5', 150 options={'max_num_steps': 0}) 151 with self.test_session() as sess: 152 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 153 'max_num_steps'): 154 sess.run(y) 155 156 y = odes.odeint(self.func, self.y0, [1, 0]) 157 with self.test_session() as sess: 158 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 159 'monotonic increasing'): 160 sess.run(y) 161 162 def test_odeint_different_times(self): 163 # integrate steps should be independent of interpolation times 164 times0 = np.linspace(0, 10, num=11, dtype=float) 165 times1 = np.linspace(0, 10, num=101, dtype=float) 166 167 with self.test_session() as sess: 168 y_solved_0, info_0 = sess.run( 169 odes.odeint(self.func, self.y0, times0, full_output=True)) 170 y_solved_1, info_1 = sess.run( 171 odes.odeint(self.func, self.y0, times1, full_output=True)) 172 173 self.assertAllClose(y_solved_0, y_solved_1[::10]) 174 self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals']) 175 self.assertAllEqual(info_0['integrate_points'], info_1['integrate_points']) 176 self.assertAllEqual(info_0['error_ratio'], info_1['error_ratio']) 177 178 def test_odeint_5th_order_accuracy(self): 179 t = [0, 20] 180 kwargs = dict( 181 full_output=True, method='dopri5', options=dict(max_num_steps=2000)) 182 with self.test_session() as sess: 183 _, info_0 = sess.run( 184 odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs)) 185 _, info_1 = sess.run( 186 odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs)) 187 self.assertAllClose( 188 info_0['integrate_points'].size * 1000**0.2, 189 float(info_1['integrate_points'].size), 190 rtol=0.01) 191 192 193class StepSizeTest(test.TestCase): 194 195 def test_error_ratio_one(self): 196 new_step = odes._optimal_step_size( 197 last_step=constant_op.constant(1.0), 198 error_ratio=constant_op.constant(1.0)) 199 with self.test_session() as sess: 200 new_step = sess.run(new_step) 201 self.assertAllClose(new_step, 0.9) 202 203 def test_ifactor(self): 204 new_step = odes._optimal_step_size( 205 last_step=constant_op.constant(1.0), 206 error_ratio=constant_op.constant(0.0)) 207 with self.test_session() as sess: 208 new_step = sess.run(new_step) 209 self.assertAllClose(new_step, 10.0) 210 211 def test_dfactor(self): 212 new_step = odes._optimal_step_size( 213 last_step=constant_op.constant(1.0), 214 error_ratio=constant_op.constant(1e6)) 215 with self.test_session() as sess: 216 new_step = sess.run(new_step) 217 self.assertAllClose(new_step, 0.2) 218 219 220class InterpolationTest(test.TestCase): 221 222 def test_5th_order_polynomial(self): 223 # this should be an exact fit 224 f = lambda x: x**4 + x**3 - 2 * x**2 + 4 * x + 5 225 f_prime = lambda x: 4 * x**3 + 3 * x**2 - 4 * x + 4 226 coeffs = odes._interp_fit( 227 f(0.0), f(10.0), f(5.0), f_prime(0.0), f_prime(10.0), 10.0) 228 times = np.linspace(0, 10, dtype=np.float32) 229 y_fit = array_ops.stack( 230 [odes._interp_evaluate(coeffs, 0.0, 10.0, t) for t in times]) 231 y_expected = f(times) 232 with self.test_session() as sess: 233 y_actual = sess.run(y_fit) 234 self.assertAllClose(y_expected, y_actual) 235 236 # attempt interpolation outside bounds 237 y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0) 238 with self.test_session() as sess: 239 with self.assertRaises(errors_impl.InvalidArgumentError): 240 sess.run(y_invalid) 241 242 243class OdeIntFixedTest(test.TestCase): 244 245 def _test_integrate_sine(self, method): 246 247 def evol_func(y, t): 248 del t 249 return array_ops.stack([y[1], -y[0]]) 250 251 y0 = [0., 1.] 252 time_grid = np.linspace(0., 10., 200) 253 y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) 254 255 with self.test_session() as sess: 256 y_grid_array = sess.run(y_grid) 257 258 np.testing.assert_allclose( 259 y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2) 260 261 def _test_integrate_gaussian(self, method): 262 263 def evol_func(y, t): 264 return -math_ops.cast(t, dtype=y.dtype) * y[0] 265 266 y0 = [1.] 267 time_grid = np.linspace(0., 2., 100) 268 y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) 269 270 with self.test_session() as sess: 271 y_grid_array = sess.run(y_grid) 272 273 np.testing.assert_allclose( 274 y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2) 275 276 def _test_everything(self, method): 277 self._test_integrate_sine(method) 278 self._test_integrate_gaussian(method) 279 280 def test_midpoint(self): 281 self._test_everything('midpoint') 282 283 def test_rk4(self): 284 self._test_everything('rk4') 285 286 287if __name__ == '__main__': 288 test.main() 289