1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""ODE solvers for TensorFlow.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import collections 23 24import six 25 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import functional_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops import tensor_array_ops 34 35_ButcherTableau = collections.namedtuple('_ButcherTableau', 36 'alpha beta c_sol c_mid c_error') 37 38# Parameters from Shampine (1986), section 4. 39_DORMAND_PRINCE_TABLEAU = _ButcherTableau( 40 alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], 41 beta=[ 42 [1 / 5], 43 [3 / 40, 9 / 40], 44 [44 / 45, -56 / 15, 32 / 9], 45 [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], 46 [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], 47 [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], 48 ], 49 c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], 50 c_mid=[ 51 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, 52 -2691868925 / 45128329728 / 2, 187940372067 / 1594534317056 / 2, 53 -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 54 ], 55 c_error=[ 56 1951 / 21600 - 35 / 384, 57 0, 58 22642 / 50085 - 500 / 1113, 59 451 / 720 - 125 / 192, 60 -12231 / 42400 - -2187 / 6784, 61 649 / 6300 - 11 / 84, 62 1 / 60, 63 ],) 64 65 66def _possibly_nonzero(x): 67 return isinstance(x, ops.Tensor) or x != 0 68 69 70def _scaled_dot_product(scale, xs, ys, name=None): 71 """Calculate a scaled, vector inner product between lists of Tensors.""" 72 with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope: 73 # Some of the parameters in our Butcher tableau include zeros. Using 74 # _possibly_nonzero lets us avoid wasted computation. 75 return math_ops.add_n( 76 [(scale * x) * y for x, y in zip(xs, ys) 77 if _possibly_nonzero(x) or _possibly_nonzero(y)], 78 name=scope) 79 80 81def _dot_product(xs, ys, name=None): 82 """Calculate the vector inner product between two lists of Tensors.""" 83 with ops.name_scope(name, 'dot_product', [xs, ys]) as scope: 84 return math_ops.add_n([x * y for x, y in zip(xs, ys)], name=scope) 85 86 87def _runge_kutta_step(func, 88 y0, 89 f0, 90 t0, 91 dt, 92 tableau=_DORMAND_PRINCE_TABLEAU, 93 name=None): 94 """Take an arbitrary Runge-Kutta step and estimate error. 95 96 Args: 97 func: Function to evaluate like `func(y, t)` to compute the time derivative 98 of `y`. 99 y0: Tensor initial value for the state. 100 f0: Tensor initial value for the derivative, computed from `func(y0, t0)`. 101 t0: float64 scalar Tensor giving the initial time. 102 dt: float64 scalar Tensor giving the size of the desired time step. 103 tableau: optional _ButcherTableau describing how to take the Runge-Kutta 104 step. 105 name: optional name for the operation. 106 107 Returns: 108 Tuple `(y1, f1, y1_error, k)` giving the estimated function value after 109 the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`, 110 estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for 111 calculating these terms. 112 """ 113 with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope: 114 y0 = ops.convert_to_tensor(y0, name='y0') 115 f0 = ops.convert_to_tensor(f0, name='f0') 116 t0 = ops.convert_to_tensor(t0, name='t0') 117 dt = ops.convert_to_tensor(dt, name='dt') 118 dt_cast = math_ops.cast(dt, y0.dtype) 119 120 k = [f0] 121 for alpha_i, beta_i in zip(tableau.alpha, tableau.beta): 122 ti = t0 + alpha_i * dt 123 yi = y0 + _scaled_dot_product(dt_cast, beta_i, k) 124 k.append(func(yi, ti)) 125 126 if not (tableau.c_sol[-1] == 0 and tableau.c_sol == tableau.beta[-1]): 127 # This property (true for Dormand-Prince) lets us save a few FLOPs. 128 yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k) 129 130 y1 = array_ops.identity(yi, name='%s/y1' % scope) 131 f1 = array_ops.identity(k[-1], name='%s/f1' % scope) 132 y1_error = _scaled_dot_product( 133 dt_cast, tableau.c_error, k, name='%s/y1_error' % scope) 134 return (y1, f1, y1_error, k) 135 136 137def _interp_fit(y0, y1, y_mid, f0, f1, dt): 138 """Fit coefficients for 4th order polynomial interpolation. 139 140 Args: 141 y0: function value at the start of the interval. 142 y1: function value at the end of the interval. 143 y_mid: function value at the mid-point of the interval. 144 f0: derivative value at the start of the interval. 145 f1: derivative value at the end of the interval. 146 dt: width of the interval. 147 148 Returns: 149 List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial 150 `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x` 151 between 0 (start of interval) and 1 (end of interval). 152 """ 153 # a, b, c, d, e = sympy.symbols('a b c d e') 154 # x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1') 155 # p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e 156 # sympy.solve([p.subs(x, 0) - y0, 157 # p.subs(x, 1 / 2) - y_mid, 158 # p.subs(x, 1) - y1, 159 # (p.diff(x) / dt).subs(x, 0) - f0, 160 # (p.diff(x) / dt).subs(x, 1) - f1], 161 # [a, b, c, d, e]) 162 # {a: -2.0*dt*f0 + 2.0*dt*f1 - 8.0*y0 - 8.0*y1 + 16.0*y_mid, 163 # b: 5.0*dt*f0 - 3.0*dt*f1 + 18.0*y0 + 14.0*y1 - 32.0*y_mid, 164 # c: -4.0*dt*f0 + dt*f1 - 11.0*y0 - 5.0*y1 + 16.0*y_mid, 165 # d: dt*f0, 166 # e: y0} 167 a = _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0, f1, y0, y1, y_mid]) 168 b = _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0, f1, y0, y1, y_mid]) 169 c = _dot_product([-4 * dt, dt, -11, -5, 16], [f0, f1, y0, y1, y_mid]) 170 d = dt * f0 171 e = y0 172 return [a, b, c, d, e] 173 174 175def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU): 176 """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" 177 with ops.name_scope('interp_fit_rk'): 178 dt = math_ops.cast(dt, y0.dtype) 179 y_mid = y0 + _scaled_dot_product(dt, tableau.c_mid, k) 180 f0 = k[0] 181 f1 = k[-1] 182 return _interp_fit(y0, y1, y_mid, f0, f1, dt) 183 184 185def _interp_evaluate(coefficients, t0, t1, t): 186 """Evaluate polynomial interpolation at the given time point. 187 188 Args: 189 coefficients: list of Tensor coefficients as created by `interp_fit`. 190 t0: scalar float64 Tensor giving the start of the interval. 191 t1: scalar float64 Tensor giving the end of the interval. 192 t: scalar float64 Tensor giving the desired interpolation point. 193 194 Returns: 195 Polynomial interpolation of the coefficients at time `t`. 196 """ 197 with ops.name_scope('interp_evaluate'): 198 t0 = ops.convert_to_tensor(t0) 199 t1 = ops.convert_to_tensor(t1) 200 t = ops.convert_to_tensor(t) 201 202 dtype = coefficients[0].dtype 203 204 assert_op = control_flow_ops.Assert( 205 (t0 <= t) & (t <= t1), 206 ['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1]) 207 with ops.control_dependencies([assert_op]): 208 x = math_ops.cast((t - t0) / (t1 - t0), dtype) 209 210 xs = [constant_op.constant(1, dtype), x] 211 for _ in range(2, len(coefficients)): 212 xs.append(xs[-1] * x) 213 214 return _dot_product(coefficients, reversed(xs)) 215 216 217def _optimal_step_size(last_step, 218 error_ratio, 219 safety=0.9, 220 ifactor=10.0, 221 dfactor=0.2, 222 order=5, 223 name=None): 224 """Calculate the optimal size for the next Runge-Kutta step.""" 225 with ops.name_scope(name, 'optimal_step_size', [last_step, 226 error_ratio]) as scope: 227 error_ratio = math_ops.cast(error_ratio, last_step.dtype) 228 exponent = math_ops.cast(1 / order, last_step.dtype) 229 # this looks more complex than necessary, but importantly it keeps 230 # error_ratio in the numerator so we can't divide by zero: 231 factor = math_ops.maximum(1 / ifactor, 232 math_ops.minimum(error_ratio**exponent / safety, 233 1 / dfactor)) 234 return math_ops.div(last_step, factor, name=scope) 235 236 237def _abs_square(x): 238 if x.dtype.is_complex: 239 return math_ops.square(math_ops.real(x)) + math_ops.square(math_ops.imag(x)) 240 else: 241 return math_ops.square(x) 242 243 244def _ta_append(tensor_array, value): 245 """Append a value to the end of a tf.TensorArray.""" 246 return tensor_array.write(tensor_array.size(), value) 247 248 249class _RungeKuttaState( 250 collections.namedtuple('_RungeKuttaState', 251 'y1, f1, t0, t1, dt, interp_coeff')): 252 """Saved state of the Runge Kutta solver. 253 254 Attributes: 255 y1: Tensor giving the function value at the end of the last time step. 256 f1: Tensor giving derivative at the end of the last time step. 257 t0: scalar float64 Tensor giving start of the last time step. 258 t1: scalar float64 Tensor giving end of the last time step. 259 dt: scalar float64 Tensor giving the size for the next time step. 260 interp_coef: list of Tensors giving coefficients for polynomial 261 interpolation between `t0` and `t1`. 262 """ 263 264 265class _History( 266 collections.namedtuple('_History', 'integrate_points, error_ratio')): 267 """Saved integration history for use in `info_dict`. 268 269 Attributes: 270 integrate_points: tf.TensorArray storing integrating time points. 271 error_ratio: tf.TensorArray storing computed error ratios at each 272 integration step. 273 """ 274 275 276def _assert_increasing(t): 277 assert_increasing = control_flow_ops.Assert( 278 math_ops.reduce_all(t[1:] > t[:-1]), ['`t` must be monotonic increasing']) 279 return ops.control_dependencies([assert_increasing]) 280 281 282def _check_input_types(t, y0): 283 if not (y0.dtype.is_floating or y0.dtype.is_complex): 284 raise TypeError('`y0` must have a floating point or complex floating ' 285 'point dtype') 286 if not t.dtype.is_floating: 287 raise TypeError('`t` must have a floating point dtype') 288 289 290def _dopri5(func, 291 y0, 292 t, 293 rtol, 294 atol, 295 full_output=False, 296 first_step=None, 297 safety=0.9, 298 ifactor=10.0, 299 dfactor=0.2, 300 max_num_steps=1000, 301 name=None): 302 """Solve an ODE for `odeint` using method='dopri5'.""" 303 304 if first_step is None: 305 # at some point, we might want to switch to picking the step size 306 # automatically 307 first_step = 1.0 308 309 with ops.name_scope(name, 'dopri5', [ 310 y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps 311 ]) as scope: 312 313 first_step = ops.convert_to_tensor( 314 first_step, dtype=t.dtype, name='first_step') 315 safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety') 316 ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor') 317 dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor') 318 max_num_steps = ops.convert_to_tensor( 319 max_num_steps, dtype=dtypes.int32, name='max_num_steps') 320 321 def adaptive_runge_kutta_step(rk_state, history, n_steps): 322 """Take an adaptive Runge-Kutta step to integrate the ODE.""" 323 y0, f0, _, t0, dt, interp_coeff = rk_state 324 with ops.name_scope('assertions'): 325 check_underflow = control_flow_ops.Assert(t0 + dt > t0, 326 ['underflow in dt', dt]) 327 check_max_num_steps = control_flow_ops.Assert( 328 n_steps < max_num_steps, ['max_num_steps exceeded']) 329 check_numerics = control_flow_ops.Assert( 330 math_ops.reduce_all(math_ops.is_finite(abs(y0))), 331 ['non-finite values in state `y`', y0]) 332 with ops.control_dependencies( 333 [check_underflow, check_max_num_steps, check_numerics]): 334 y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt) 335 336 with ops.name_scope('error_ratio'): 337 # We use the same approach as the dopri5 fortran code. 338 error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1)) 339 tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol) 340 # Could also use reduce_maximum here. 341 error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio)) 342 accept_step = error_ratio <= 1 343 344 with ops.name_scope('update/rk_state'): 345 # If we don't accept the step, the _RungeKuttaState will be useless 346 # (covering a time-interval of size 0), but that's OK, because in such 347 # cases we always immediately take another Runge-Kutta step. 348 y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0) 349 f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0) 350 t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0) 351 interp_coeff = control_flow_ops.cond( 352 accept_step, lambda: _interp_fit_rk(y0, y1, k, dt), 353 lambda: interp_coeff) 354 dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor) 355 rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, 356 interp_coeff) 357 358 with ops.name_scope('update/history'): 359 history = _History( 360 _ta_append(history.integrate_points, t0 + dt), 361 _ta_append(history.error_ratio, error_ratio)) 362 return rk_state, history, n_steps + 1 363 364 def interpolate(solution, history, rk_state, i): 365 """Interpolate through the next time point, integrating as necessary.""" 366 with ops.name_scope('interpolate'): 367 rk_state, history, _ = control_flow_ops.while_loop( 368 lambda rk_state, *_: t[i] > rk_state.t1, 369 adaptive_runge_kutta_step, (rk_state, history, 0), 370 name='integrate_loop') 371 y = _interp_evaluate(rk_state.interp_coeff, rk_state.t0, rk_state.t1, 372 t[i]) 373 solution = solution.write(i, y) 374 return solution, history, rk_state, i + 1 375 376 with _assert_increasing(t): 377 num_times = array_ops.size(t) 378 379 solution = tensor_array_ops.TensorArray( 380 y0.dtype, size=num_times).write(0, y0) 381 history = _History( 382 integrate_points=tensor_array_ops.TensorArray( 383 t.dtype, size=0, dynamic_size=True), 384 error_ratio=tensor_array_ops.TensorArray( 385 rtol.dtype, size=0, dynamic_size=True)) 386 rk_state = _RungeKuttaState( 387 y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5) 388 389 solution, history, _, _ = control_flow_ops.while_loop( 390 lambda _, __, ___, i: i < num_times, 391 interpolate, (solution, history, rk_state, 1), 392 name='interpolate_loop') 393 394 y = solution.stack(name=scope) 395 y.set_shape(t.get_shape().concatenate(y0.get_shape())) 396 if not full_output: 397 return y 398 else: 399 integrate_points = history.integrate_points.stack() 400 info_dict = { 401 'num_func_evals': 6 * array_ops.size(integrate_points) + 1, 402 'integrate_points': integrate_points, 403 'error_ratio': history.error_ratio.stack() 404 } 405 return (y, info_dict) 406 407 408def odeint(func, 409 y0, 410 t, 411 rtol=1e-6, 412 atol=1e-12, 413 method=None, 414 options=None, 415 full_output=False, 416 name=None): 417 """Integrate a system of ordinary differential equations. 418 419 Solves the initial value problem for a non-stiff system of first order ODEs: 420 421 ``` 422 dy/dt = func(y, t), y(t[0]) = y0 423 ``` 424 425 where y is a Tensor of any shape. 426 427 For example: 428 429 ``` 430 # solve `dy/dt = -y`, corresponding to exponential decay 431 tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2]) 432 => [1, exp(-1), exp(-2)] 433 ``` 434 435 Output dtypes and numerical precision are based on the dtypes of the inputs 436 `y0` and `t`. 437 438 Currently, implements 5th order Runge-Kutta with adaptive step size control 439 and dense output, using the Dormand-Prince method. Similar to the 'dopri5' 440 method of `scipy.integrate.ode` and MATLAB's `ode45`. 441 442 Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas", 443 Mathematics of Computation, American Mathematical Society, 46 (173): 135-150, 444 doi:10.2307/2008219 445 446 Args: 447 func: Function that maps a Tensor holding the state `y` and a scalar Tensor 448 `t` into a Tensor of state derivatives with respect to time. 449 y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May 450 have any floating point or complex dtype. 451 t: 1-D Tensor holding a sequence of time points for which to solve for 452 `y`. The initial time point should be the first element of this sequence, 453 and each time must be larger than the previous time. May have any floating 454 point dtype. If not provided as a Tensor, converted to a Tensor with 455 float64 dtype. 456 rtol: optional float64 Tensor specifying an upper bound on relative error, 457 per element of `y`. 458 atol: optional float64 Tensor specifying an upper bound on absolute error, 459 per element of `y`. 460 method: optional string indicating the integration method to use. Currently, 461 the only valid option is `'dopri5'`. 462 options: optional dict of configuring options for the indicated integration 463 method. Can only be provided if a `method` is explicitly set. For 464 `'dopri5'`, valid options include: 465 * first_step: an initial guess for the size of the first integration 466 (current default: 1.0, but may later be changed to use heuristics based 467 on the gradient). 468 * safety: safety factor for adaptive step control, generally a constant 469 in the range 0.8-1 (default: 0.9). 470 * ifactor: maximum factor by which the adaptive step may be increased 471 (default: 10.0). 472 * dfactor: maximum factor by which the adpative step may be decreased 473 (default: 0.2). 474 * max_num_steps: integer maximum number of integrate steps between time 475 points in `t` (default: 1000). 476 full_output: optional boolean. If True, `odeint` returns a tuple 477 `(y, info_dict)` describing the integration process. 478 name: Optional name for this operation. 479 480 Returns: 481 y: (N+1)-D tensor, where the first dimension corresponds to different 482 time points. Contains the solved value of y for each desired time point in 483 `t`, with the initial value `y0` being the first element along the first 484 dimension. 485 info_dict: only if `full_output == True`. A dict with the following values: 486 * num_func_evals: integer Tensor counting the number of function 487 evaluations. 488 * integrate_points: 1D float64 Tensor with the upper bound of each 489 integration time step. 490 * error_ratio: 1D float Tensor with the estimated ratio of the integration 491 error to the error tolerance at each integration step. An ratio greater 492 than 1 corresponds to rejected steps. 493 494 Raises: 495 ValueError: if an invalid `method` is provided. 496 TypeError: if `options` is supplied without `method`, or if `t` or `y0` has 497 an invalid dtype. 498 """ 499 if method is not None and method != 'dopri5': 500 raise ValueError('invalid method: %r' % method) 501 502 if options is None: 503 options = {} 504 elif method is None: 505 raise ValueError('cannot supply `options` without specifying `method`') 506 507 with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope: 508 # TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an 509 # arbitrarily nested tuple. This will help performance and usability by 510 # avoiding the need to pack/unpack in user functions. 511 y0 = ops.convert_to_tensor(y0, name='y0') 512 t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') 513 _check_input_types(t, y0) 514 515 error_dtype = abs(y0).dtype 516 rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol') 517 atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol') 518 519 return _dopri5( 520 func, 521 y0, 522 t, 523 rtol=rtol, 524 atol=atol, 525 full_output=full_output, 526 name=scope, 527 **options) 528 529 530class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): 531 """Base class for fixed-grid ODE integrators.""" 532 533 def integrate(self, evol_func, y0, time_grid): 534 time_delta_grid = time_grid[1:] - time_grid[:-1] 535 536 scan_func = self._make_scan_func(evol_func) 537 538 y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid), 539 y0) 540 return array_ops.concat([[y0], y_grid], axis=0) 541 542 def _make_scan_func(self, evol_func): 543 544 def scan_func(y, t_and_dt): 545 t, dt = t_and_dt 546 dy = self._step_func(evol_func, t, dt, y) 547 dy = math_ops.cast(dy, dtype=y.dtype) 548 return y + dy 549 550 return scan_func 551 552 @abc.abstractmethod 553 def _step_func(self, evol_func, t, dt, y): 554 pass 555 556 557class _MidpointFixedGridIntegrator(_FixedGridIntegrator): 558 559 def _step_func(self, evol_func, t, dt, y): 560 dt_cast = math_ops.cast(dt, y.dtype) 561 # yn1 = yn + h * f(tn + h/2, yn + f(tn, yn) * h/2) 562 return dt_cast * evol_func(y + evol_func(y, t) * dt_cast / 2, t + dt / 2) 563 564 565class _RK4FixedGridIntegrator(_FixedGridIntegrator): 566 567 def _step_func(self, evol_func, t, dt, y): 568 k1 = evol_func(y, t) 569 half_step = t + dt / 2 570 dt_cast = math_ops.cast(dt, y.dtype) 571 572 k2 = evol_func(y + dt_cast * k1 / 2, half_step) 573 k3 = evol_func(y + dt_cast * k2 / 2, half_step) 574 k4 = evol_func(y + dt_cast * k3, t + dt) 575 return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6) 576 577 578def odeint_fixed(func, y0, t, method='rk4', name=None): 579 """ODE integration on a fixed grid (with no step size control). 580 581 Useful in certain scenarios to avoid the overhead of adaptive step size 582 control, e.g. when differentiation of the integration result is desired and/or 583 the time grid is known a priori to be sufficient. 584 585 Args: 586 func: Function that maps a Tensor holding the state `y` and a scalar Tensor 587 `t` into a Tensor of state derivatives with respect to time. 588 y0: N-D Tensor giving starting value of `y` at time point `t[0]`. 589 t: 1-D Tensor holding a sequence of time points for which to solve for 590 `y`. The initial time point should be the first element of this sequence, 591 and each time must be larger than the previous time. May have any floating 592 point dtype. 593 method: One of 'midpoint' or 'rk4'. 594 name: Optional name for the resulting operation. 595 596 Returns: 597 y: (N+1)-D tensor, where the first dimension corresponds to different 598 time points. Contains the solved value of y for each desired time point in 599 `t`, with the initial value `y0` being the first element along the first 600 dimension. 601 602 Raises: 603 ValueError: Upon caller errors. 604 """ 605 with ops.name_scope(name, 'odeint_fixed', [y0, t]): 606 t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') 607 y0 = ops.convert_to_tensor(y0, name='y0') 608 _check_input_types(t, y0) 609 610 with _assert_increasing(t): 611 with ops.name_scope(method): 612 if method == 'midpoint': 613 return _MidpointFixedGridIntegrator().integrate(func, y0, t) 614 elif method == 'rk4': 615 return _RK4FixedGridIntegrator().integrate(func, y0, t) 616 else: 617 raise ValueError('method not supported: {!s}'.format(method)) 618