1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""ODE solvers for TensorFlow."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import collections
23
24import six
25
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import functional_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import tensor_array_ops
34
# Coefficients of an explicit Runge-Kutta method. Fields, as used in this file:
#   alpha: fraction of `dt` at which each stage after the first is evaluated
#     (`_runge_kutta_step` uses `ti = t0 + alpha_i * dt`).
#   beta: row i holds the weights applied to the slopes `k` computed so far to
#     form the input state of the next stage.
#   c_sol: weights combining all slopes into the step's solution `y1`.
#   c_mid: weights combining all slopes into the state at the midpoint of the
#     step (see `_interp_fit_rk`), used for dense-output interpolation.
#   c_error: weights combining all slopes into the local error estimate.
_ButcherTableau = collections.namedtuple('_ButcherTableau',
                                         'alpha beta c_sol c_mid c_error')

# Parameters from Shampine (1986), section 4.
# True division (`from __future__ import division` at the top of the file)
# keeps these fractions as floats under Python 2 as well.
# The last row of `beta` equals `c_sol` without its trailing zero, so the last
# stage is evaluated at the step's solution point (a "first same as last"
# property that `_runge_kutta_step` looks for to avoid recomputing `y1`).
_DORMAND_PRINCE_TABLEAU = _ButcherTableau(
    alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.],
    beta=[
        [1 / 5],
        [3 / 40, 9 / 40],
        [44 / 45, -56 / 15, 32 / 9],
        [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
        [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
        [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
    ],
    c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0],
    # Each weight carries a trailing `/ 2`; the halved weights sum to about
    # 1/2, consistent with advancing only half of `dt` from `y0`.
    c_mid=[
        6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2,
        -2691868925 / 45128329728 / 2, 187940372067 / 1594534317056 / 2,
        -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2
    ],
    # Each entry is an alternative solution weight minus the matching `c_sol`
    # weight (the last `c_sol` weight is 0), so `dt * dot(c_error, k)` is the
    # difference between the two solution estimates.
    c_error=[
        1951 / 21600 - 35 / 384,
        0,
        22642 / 50085 - 500 / 1113,
        451 / 720 - 125 / 192,
        -12231 / 42400 - -2187 / 6784,
        649 / 6300 - 11 / 84,
        1 / 60,
    ],)
64
65
def _possibly_nonzero(x):
  """Return whether `x` might be nonzero.

  Tensors are always reported as possibly nonzero because their values are not
  inspected here; any other value is compared against zero.
  """
  if isinstance(x, ops.Tensor):
    return True
  return x != 0
68
69
def _scaled_dot_product(scale, xs, ys, name=None):
  """Calculate a scaled, vector inner product between lists of Tensors.

  Computes `sum((scale * x) * y for x, y in zip(xs, ys))`, skipping every term
  in which `x` or `y` is a literal (non-Tensor) zero.

  Args:
    scale: scalar multiplier applied to every term.
    xs: list of coefficients (Python numbers or Tensors).
    ys: list of Tensors (or numbers), paired element-wise with `xs`.
    name: optional name for the operation.

  Returns:
    Tensor holding the scaled inner product. At least one term must survive
    the zero filter, since `math_ops.add_n` does not accept an empty list.
  """
  with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope:
    # Some of the parameters in our Butcher tableau include zeros. A product
    # vanishes as soon as EITHER factor is a literal zero, so a term is kept
    # only when BOTH factors are possibly nonzero. (The `ys` passed in are
    # Tensors, so an `or` test here would keep every term and skip nothing.)
    return math_ops.add_n(
        [(scale * x) * y for x, y in zip(xs, ys)
         if _possibly_nonzero(x) and _possibly_nonzero(y)],
        name=scope)
79
80
def _dot_product(xs, ys, name=None):
  """Return the sum of the element-wise products of two lists of Tensors.

  Pairs are formed with `zip`, so the shorter input determines how many terms
  are summed.
  """
  with ops.name_scope(name, 'dot_product', [xs, ys]) as scope:
    products = [x * y for x, y in zip(xs, ys)]
    return math_ops.add_n(products, name=scope)
85
86
def _runge_kutta_step(func,
                      y0,
                      f0,
                      t0,
                      dt,
                      tableau=_DORMAND_PRINCE_TABLEAU,
                      name=None):
  """Take an arbitrary Runge-Kutta step and estimate error.

  Args:
    func: Function to evaluate like `func(y, t)` to compute the time derivative
      of `y`.
    y0: Tensor initial value for the state.
    f0: Tensor initial value for the derivative, computed from `func(y0, t0)`.
    t0: float64 scalar Tensor giving the initial time.
    dt: float64 scalar Tensor giving the size of the desired time step.
    tableau: optional _ButcherTableau describing how to take the Runge-Kutta
      step.
    name: optional name for the operation.

  Returns:
    Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
    the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
    estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
    calculating these terms. `f1` is the last stage slope `k[-1]`; it equals
    `func(y1, t1)` for "first same as last" tableaus such as Dormand-Prince
    (whose final `alpha` is 1).
  """
  with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope:
    y0 = ops.convert_to_tensor(y0, name='y0')
    f0 = ops.convert_to_tensor(f0, name='f0')
    t0 = ops.convert_to_tensor(t0, name='t0')
    dt = ops.convert_to_tensor(dt, name='dt')
    # State increments are formed in the dtype of `y`, which may differ from
    # the dtype of `t` (e.g. a complex state with a real time grid).
    dt_cast = math_ops.cast(dt, y0.dtype)

    k = [f0]
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
      ti = t0 + alpha_i * dt
      yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
      k.append(func(yi, ti))

    # "First same as last": when the solution weights, minus a trailing zero,
    # equal the last row of `beta`, the final stage input `yi` already is the
    # solution `y1`. This property (true for Dormand-Prince) lets us save a few
    # FLOPs. `c_sol` has one more entry than `beta[-1]`, so the trailing
    # element must be dropped before comparing.
    first_same_as_last = (
        tableau.c_sol[-1] == 0 and
        list(tableau.c_sol[:-1]) == list(tableau.beta[-1]))
    if not first_same_as_last:
      yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)

    y1 = array_ops.identity(yi, name='%s/y1' % scope)
    f1 = array_ops.identity(k[-1], name='%s/f1' % scope)
    y1_error = _scaled_dot_product(
        dt_cast, tableau.c_error, k, name='%s/y1_error' % scope)
    return (y1, f1, y1_error, k)
135
136
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
  """Fit the coefficients of a quartic interpolant over one step.

  The polynomial `p(x) = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e`,
  with `x` running from 0 (start of the interval) to 1 (end of the interval),
  is chosen so that it matches the values `y0`, `y_mid`, `y1` at
  `x = 0, 1/2, 1` and the time derivatives `f0`, `f1` at `x = 0, 1`.

  Args:
    y0: function value at the start of the interval.
    y1: function value at the end of the interval.
    y_mid: function value at the mid-point of the interval.
    f0: derivative value at the start of the interval.
    f1: derivative value at the end of the interval.
    dt: width of the interval.

  Returns:
    List `[a, b, c, d, e]` of polynomial coefficients, highest power first.
  """
  # a, b, c, d, e = sympy.symbols('a b c d e')
  # x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1')
  # p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
  # sympy.solve([p.subs(x, 0) - y0,
  #              p.subs(x, 1 / 2) - y_mid,
  #              p.subs(x, 1) - y1,
  #              (p.diff(x) / dt).subs(x, 0) - f0,
  #              (p.diff(x) / dt).subs(x, 1) - f1],
  #             [a, b, c, d, e])
  # {a: -2.0*dt*f0 + 2.0*dt*f1 - 8.0*y0 - 8.0*y1 + 16.0*y_mid,
  #  b: 5.0*dt*f0 - 3.0*dt*f1 + 18.0*y0 + 14.0*y1 - 32.0*y_mid,
  #  c: -4.0*dt*f0 + dt*f1 - 11.0*y0 - 5.0*y1 + 16.0*y_mid,
  #  d: dt*f0,
  #  e: y0}
  known_values = [f0, f1, y0, y1, y_mid]

  def combine(*weights):
    # Weights line up with `known_values` position by position.
    return _dot_product(list(weights), known_values)

  quartic = combine(-2 * dt, 2 * dt, -8, -8, 16)
  cubic = combine(5 * dt, -3 * dt, 18, 14, -32)
  quadratic = combine(-4 * dt, dt, -11, -5, 16)
  linear = dt * f0
  constant = y0
  return [quartic, cubic, quadratic, linear, constant]
173
174
def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU):
  """Build interpolation coefficients from the stages of one Runge-Kutta step.

  The midpoint state is reconstructed from the stage slopes `k` using the
  tableau's `c_mid` weights; the first and last slopes serve as the endpoint
  derivatives.
  """
  with ops.name_scope('interp_fit_rk'):
    step = math_ops.cast(dt, y0.dtype)
    start_slope, end_slope = k[0], k[-1]
    midpoint_value = y0 + _scaled_dot_product(step, tableau.c_mid, k)
    return _interp_fit(y0, y1, midpoint_value, start_slope, end_slope, step)
183
184
def _interp_evaluate(coefficients, t0, t1, t):
  """Evaluate polynomial interpolation at the given time point.

  `t` is mapped to `x = (t - t0) / (t1 - t0)` and the polynomial whose
  coefficients are listed highest power first is evaluated at `x`.

  Args:
    coefficients: list of Tensor coefficients as created by `_interp_fit`.
    t0: scalar Tensor giving the start of the interval.
    t1: scalar Tensor giving the end of the interval.
    t: scalar Tensor giving the desired interpolation point; a runtime
      assertion checks `t0 <= t <= t1`.

  Returns:
    Polynomial interpolation of the coefficients at time `t`.
  """
  with ops.name_scope('interp_evaluate'):
    t0, t1, t = [ops.convert_to_tensor(value) for value in (t0, t1, t)]
    poly_dtype = coefficients[0].dtype

    in_interval = (t0 <= t) & (t <= t1)
    check_in_interval = control_flow_ops.Assert(
        in_interval,
        ['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1])
    # Normalize the time only once the range check has been run.
    with ops.control_dependencies([check_in_interval]):
      x = math_ops.cast((t - t0) / (t1 - t0), poly_dtype)

    # powers = [1, x, x ** 2, ...] with one entry per coefficient (never
    # fewer than two).
    powers = [constant_op.constant(1, poly_dtype), x]
    while len(powers) < len(coefficients):
      powers.append(powers[-1] * x)

    # Coefficients are highest power first, so pair them with descending
    # powers.
    return _dot_product(coefficients, powers[::-1])
215
216
def _optimal_step_size(last_step,
                       error_ratio,
                       safety=0.9,
                       ifactor=10.0,
                       dfactor=0.2,
                       order=5,
                       name=None):
  """Calculate the optimal size for the next Runge-Kutta step.

  The proposal is `last_step / factor`, where
  `factor = error_ratio ** (1 / order) / safety` is clipped into
  `[1 / ifactor, 1 / dfactor]`. The step can therefore grow by at most
  `ifactor` and shrink to no less than `dfactor` times its previous size.

  Args:
    last_step: scalar Tensor giving the size of the step just attempted.
    error_ratio: ratio of estimated error to tolerance for that step.
    safety: safety factor dividing the raw error-based factor.
    ifactor: bound on growth of the step size.
    dfactor: bound on shrinkage of the step size.
    order: order used for the `1 / order` exponent.
    name: optional name for the operation.

  Returns:
    Tensor with the dtype of `last_step` giving the proposed next step size.
  """
  with ops.name_scope(name, 'optimal_step_size', [last_step,
                                                  error_ratio]) as scope:
    step_dtype = last_step.dtype
    error_ratio = math_ops.cast(error_ratio, step_dtype)
    exponent = math_ops.cast(1 / order, step_dtype)
    # Keeping `error_ratio` in the numerator of `factor` means a zero error
    # ratio cannot cause a division by zero; it just yields maximal growth.
    growth_bound = 1 / ifactor
    raw_factor = error_ratio**exponent / safety
    shrink_bound = 1 / dfactor
    factor = math_ops.maximum(growth_bound,
                              math_ops.minimum(raw_factor, shrink_bound))
    return math_ops.div(last_step, factor, name=scope)
235
236
def _abs_square(x):
  """Return the element-wise squared magnitude `|x| ** 2`.

  For complex `x` this is `real(x) ** 2 + imag(x) ** 2`; otherwise `x ** 2`.
  """
  if not x.dtype.is_complex:
    return math_ops.square(x)
  real_part = math_ops.real(x)
  imag_part = math_ops.imag(x)
  return math_ops.square(real_part) + math_ops.square(imag_part)
242
243
def _ta_append(tensor_array, value):
  """Write `value` at the current end (index `size()`) of `tensor_array`.

  Returns:
    The TensorArray returned by `write`, which includes the appended value.
  """
  end_index = tensor_array.size()
  return tensor_array.write(end_index, value)
247
248
249class _RungeKuttaState(
250    collections.namedtuple('_RungeKuttaState',
251                           'y1, f1, t0, t1, dt, interp_coeff')):
252  """Saved state of the Runge Kutta solver.
253
254  Attributes:
255    y1: Tensor giving the function value at the end of the last time step.
256    f1: Tensor giving derivative at the end of the last time step.
257    t0: scalar float64 Tensor giving start of the last time step.
258    t1: scalar float64 Tensor giving end of the last time step.
259    dt: scalar float64 Tensor giving the size for the next time step.
260    interp_coef: list of Tensors giving coefficients for polynomial
261      interpolation between `t0` and `t1`.
262  """
263
264
265class _History(
266    collections.namedtuple('_History', 'integrate_points, error_ratio')):
267  """Saved integration history for use in `info_dict`.
268
269  Attributes:
270    integrate_points: tf.TensorArray storing integrating time points.
271    error_ratio: tf.TensorArray storing computed error ratios at each
272      integration step.
273  """
274
275
def _assert_increasing(t):
  """Return a control-dependency context asserting `t` is strictly increasing.

  Ops created inside the returned context depend on the assertion op.
  """
  later, earlier = t[1:], t[:-1]
  is_increasing = math_ops.reduce_all(later > earlier)
  monotonic_check = control_flow_ops.Assert(
      is_increasing, ['`t` must be monotonic increasing'])
  return ops.control_dependencies([monotonic_check])
280
281
def _check_input_types(t, y0):
  """Validate the dtypes of the time grid `t` and the initial state `y0`.

  Raises:
    TypeError: if `y0` is neither floating point nor complex, or if `t` is
      not floating point.
  """
  y0_dtype = y0.dtype
  if not y0_dtype.is_floating and not y0_dtype.is_complex:
    raise TypeError('`y0` must have a floating point or complex floating '
                    'point dtype')
  if t.dtype.is_floating:
    return
  raise TypeError('`t` must have a floating point dtype')
288
289
def _dopri5(func,
            y0,
            t,
            rtol,
            atol,
            full_output=False,
            first_step=None,
            safety=0.9,
            ifactor=10.0,
            dfactor=0.2,
            max_num_steps=1000,
            name=None):
  """Solve an ODE for `odeint` using method='dopri5'.

  Integrates with adaptive Runge-Kutta steps (`_runge_kutta_step` with its
  default Dormand-Prince tableau) and produces the solution at each requested
  time in `t` by evaluating the interpolating polynomial of the step that
  covers it.

  Args:
    func: function `func(y, t)` returning the time derivative of `y`.
    y0: Tensor giving the state at `t[0]`.
    t: 1-D Tensor of output times; asserted at runtime to be strictly
      increasing.
    rtol: Tensor relative error tolerance.
    atol: Tensor absolute error tolerance.
    full_output: if True, also return an `info_dict` describing the run.
    first_step: initial step size; 1.0 when None.
    safety: safety factor passed to `_optimal_step_size`.
    ifactor: bound on step growth passed to `_optimal_step_size`.
    dfactor: bound on step shrinkage passed to `_optimal_step_size`.
    max_num_steps: maximum number of attempted steps between two consecutive
      output times; exceeding it fails a runtime assertion.
    name: optional name scope.

  Returns:
    Tensor `y` with shape `t.shape + y0.shape`, or `(y, info_dict)` when
    `full_output` is True.
  """

  if first_step is None:
    # at some point, we might want to switch to picking the step size
    # automatically
    first_step = 1.0

  with ops.name_scope(name, 'dopri5', [
      y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps
  ]) as scope:

    # Step-control parameters live in the dtype of `t`; the step counter is
    # int32.
    first_step = ops.convert_to_tensor(
        first_step, dtype=t.dtype, name='first_step')
    safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety')
    ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor')
    dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor')
    max_num_steps = ops.convert_to_tensor(
        max_num_steps, dtype=dtypes.int32, name='max_num_steps')

    def adaptive_runge_kutta_step(rk_state, history, n_steps):
      """Take an adaptive Runge-Kutta step to integrate the ODE."""
      # The end of the previous step (y1, f1, t1) becomes the start of this
      # one; the previous start time is not needed.
      y0, f0, _, t0, dt, interp_coeff = rk_state
      with ops.name_scope('assertions'):
        check_underflow = control_flow_ops.Assert(t0 + dt > t0,
                                                  ['underflow in dt', dt])
        check_max_num_steps = control_flow_ops.Assert(
            n_steps < max_num_steps, ['max_num_steps exceeded'])
        check_numerics = control_flow_ops.Assert(
            math_ops.reduce_all(math_ops.is_finite(abs(y0))),
            ['non-finite values in state `y`', y0])
      with ops.control_dependencies(
          [check_underflow, check_max_num_steps, check_numerics]):
        y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt)

      with ops.name_scope('error_ratio'):
        # We use the same approach as the dopri5 fortran code.
        # Per-element tolerance mixes absolute and relative error; the ratio
        # is the root-mean-square of |error| / tolerance over all elements.
        error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1))
        tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol)
        # Could also use reduce_maximum here.
        error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio))
        accept_step = error_ratio <= 1

      with ops.name_scope('update/rk_state'):
        # If we don't accept the step, the _RungeKuttaState will be useless
        # (covering a time-interval of size 0), but that's OK, because in such
        # cases we always immediately take another Runge-Kutta step.
        y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0)
        f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0)
        t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0)
        # The false branch reads the unpacked coefficients; both branches are
        # traced inside `cond` before this assignment rebinds the name.
        interp_coeff = control_flow_ops.cond(
            accept_step, lambda: _interp_fit_rk(y0, y1, k, dt),
            lambda: interp_coeff)
        # The next step size is adapted whether or not this step was accepted.
        dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor)
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next,
                                    interp_coeff)

      with ops.name_scope('update/history'):
        # Every attempted step is recorded, including rejected ones.
        history = _History(
            _ta_append(history.integrate_points, t0 + dt),
            _ta_append(history.error_ratio, error_ratio))
      return rk_state, history, n_steps + 1

    def interpolate(solution, history, rk_state, i):
      """Interpolate through the next time point, integrating as necessary."""
      with ops.name_scope('interpolate'):
        # Step until the current step's end time reaches t[i]; the step
        # counter restarts at 0 for every output time.
        rk_state, history, _ = control_flow_ops.while_loop(
            lambda rk_state, *_: t[i] > rk_state.t1,
            adaptive_runge_kutta_step, (rk_state, history, 0),
            name='integrate_loop')
        y = _interp_evaluate(rk_state.interp_coeff, rk_state.t0, rk_state.t1,
                             t[i])
        solution = solution.write(i, y)
        return solution, history, rk_state, i + 1

    with _assert_increasing(t):
      num_times = array_ops.size(t)

    # Entry 0 of the solution is the initial state itself.
    solution = tensor_array_ops.TensorArray(
        y0.dtype, size=num_times).write(0, y0)
    history = _History(
        integrate_points=tensor_array_ops.TensorArray(
            t.dtype, size=0, dynamic_size=True),
        error_ratio=tensor_array_ops.TensorArray(
            rtol.dtype, size=0, dynamic_size=True))
    # Placeholder coefficients: five Tensors, the same structure that
    # `_interp_fit_rk` returns (presumably so the loop state keeps a fixed
    # structure across iterations).
    rk_state = _RungeKuttaState(
        y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5)

    solution, history, _, _ = control_flow_ops.while_loop(
        lambda _, __, ___, i: i < num_times,
        interpolate, (solution, history, rk_state, 1),
        name='interpolate_loop')

    y = solution.stack(name=scope)
    y.set_shape(t.get_shape().concatenate(y0.get_shape()))
    if not full_output:
      return y
    else:
      integrate_points = history.integrate_points.stack()
      info_dict = {
          # Six `func` evaluations per attempted step (one per entry of the
          # default tableau's `alpha`), plus the initial `func(y0, t[0])`.
          'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
          'integrate_points': integrate_points,
          'error_ratio': history.error_ratio.stack()
      }
      return (y, info_dict)
406
407
def odeint(func,
           y0,
           t,
           rtol=1e-6,
           atol=1e-12,
           method=None,
           options=None,
           full_output=False,
           name=None):
  """Integrate a system of ordinary differential equations.

  Solves the initial value problem for a non-stiff system of first order ODEs:

    ```
    dy/dt = func(y, t), y(t[0]) = y0
    ```

  where y is a Tensor of any shape.

  For example:

    ```
    # solve `dy/dt = -y`, corresponding to exponential decay
    tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
    => [1, exp(-1), exp(-2)]
    ```

  Output dtypes and numerical precision are based on the dtypes of the inputs
  `y0` and `t`.

  Currently, implements 5th order Runge-Kutta with adaptive step size control
  and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
  method of `scipy.integrate.ode` and MATLAB's `ode45`.

  Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
  Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
  doi:10.2307/2008219

  Args:
    func: Function that maps a Tensor holding the state `y` and a scalar Tensor
      `t` into a Tensor of state derivatives with respect to time.
    y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
      have any floating point or complex dtype.
    t: 1-D Tensor holding a sequence of time points for which to solve for
      `y`. The initial time point should be the first element of this sequence,
      and each time must be larger than the previous time. May have any floating
      point dtype. If not provided as a Tensor, converted to a Tensor with
      float64 dtype.
    rtol: optional float Tensor specifying an upper bound on relative error,
      per element of `y`. Converted to the real dtype of `abs(y0)`.
    atol: optional float Tensor specifying an upper bound on absolute error,
      per element of `y`. Converted to the real dtype of `abs(y0)`.
    method: optional string indicating the integration method to use. Currently,
      the only valid option is `'dopri5'`.
    options: optional dict of configuring options for the indicated integration
      method. Can only be provided if a `method` is explicitly set. For
      `'dopri5'`, valid options include:
      * first_step: an initial guess for the size of the first integration
        step (current default: 1.0, but may later be changed to use
        heuristics based on the gradient).
      * safety: safety factor for adaptive step control, generally a constant
        in the range 0.8-1 (default: 0.9).
      * ifactor: maximum factor by which the adaptive step may be increased
        (default: 10.0).
      * dfactor: limit on how much the adaptive step may be decreased; a new
        step is never smaller than `dfactor` times the previous one
        (default: 0.2).
      * max_num_steps: integer maximum number of integrate steps between time
        points in `t` (default: 1000).
    full_output: optional boolean. If True, `odeint` returns a tuple
      `(y, info_dict)` describing the integration process.
    name: Optional name for this operation.

  Returns:
    y: (N+1)-D tensor, where the first dimension corresponds to different
      time points. Contains the solved value of y for each desired time point in
      `t`, with the initial value `y0` being the first element along the first
      dimension.
    info_dict: only if `full_output == True`. A dict with the following values:
      * num_func_evals: integer Tensor counting the number of function
        evaluations.
      * integrate_points: 1D Tensor (with the dtype of `t`) with the upper
        bound of each attempted integration time step.
      * error_ratio: 1D float Tensor with the estimated ratio of the integration
        error to the error tolerance at each integration step. A ratio greater
        than 1 corresponds to rejected steps.

  Raises:
    ValueError: if an invalid `method` is provided, or if `options` is
      supplied without `method`.
    TypeError: if `t` or `y0` has an invalid dtype, or if `options` contains
      an unrecognized key.
  """
  if method is not None and method != 'dopri5':
    raise ValueError('invalid method: %r' % method)

  if options is None:
    options = {}
  elif method is None:
    raise ValueError('cannot supply `options` without specifying `method`')

  with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope:
    # TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an
    # arbitrarily nested tuple. This will help performance and usability by
    # avoiding the need to pack/unpack in user functions.
    y0 = ops.convert_to_tensor(y0, name='y0')
    t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
    _check_input_types(t, y0)

    # Tolerances are compared against magnitudes, so they use the real dtype
    # of `abs(y0)` (e.g. float64 for a complex128 state).
    error_dtype = abs(y0).dtype
    rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
    atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol')

    return _dopri5(
        func,
        y0,
        t,
        rtol=rtol,
        atol=atol,
        full_output=full_output,
        name=scope,
        **options)
528
529
class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
  """Base class for fixed-grid ODE integrators.

  Subclasses supply `_step_func`, which returns the state increment for one
  step; `integrate` scans it across consecutive intervals of the grid.
  """

  def integrate(self, evol_func, y0, time_grid):
    """Integrate `evol_func` from `y0` across every interval of `time_grid`.

    Args:
      evol_func: function `evol_func(y, t)` giving the time derivative of `y`.
      y0: Tensor holding the state at `time_grid[0]`.
      time_grid: 1-D Tensor of grid times.

    Returns:
      Tensor whose first entry is `y0`, followed by the state at each later
      grid point.
    """
    step_sizes = time_grid[1:] - time_grid[:-1]
    advance = self._make_scan_func(evol_func)
    step_starts = time_grid[:-1]
    states = functional_ops.scan(advance, (step_starts, step_sizes), y0)
    return array_ops.concat([[y0], states], axis=0)

  def _make_scan_func(self, evol_func):
    """Return a `scan` body mapping `(y, (t, dt))` to the state after `dt`."""

    def advance_one_step(y, t_and_dt):
      t, dt = t_and_dt
      # Cast so the accumulated state keeps the dtype of `y`.
      increment = math_ops.cast(
          self._step_func(evol_func, t, dt, y), dtype=y.dtype)
      return y + increment

    return advance_one_step

  @abc.abstractmethod
  def _step_func(self, evol_func, t, dt, y):
    """Return the state increment for one step of size `dt` starting at `t`."""
555
556
class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
  """Fixed-grid integrator using the explicit midpoint rule."""

  def _step_func(self, evol_func, t, dt, y):
    """Return the increment `dt * f(t + dt/2, y + f(t, y) * dt/2)`."""
    # yn1 = yn + h * f(tn + h/2, yn + f(tn, yn) * h/2)
    h = math_ops.cast(dt, y.dtype)
    y_mid = y + evol_func(y, t) * h / 2
    t_mid = t + dt / 2
    return h * evol_func(y_mid, t_mid)
563
564
class _RK4FixedGridIntegrator(_FixedGridIntegrator):
  """Fixed-grid integrator using the classical fourth-order Runge-Kutta rule."""

  def _step_func(self, evol_func, t, dt, y):
    """Return `dt / 6` times the 1-2-2-1 weighted sum of the four RK4 slopes."""
    slope_start = evol_func(y, t)
    t_mid = t + dt / 2
    h = math_ops.cast(dt, y.dtype)

    slope_mid_a = evol_func(y + h * slope_start / 2, t_mid)
    slope_mid_b = evol_func(y + h * slope_mid_a / 2, t_mid)
    slope_end = evol_func(y + h * slope_mid_b, t + dt)
    weighted_sum = math_ops.add_n(
        [slope_start, 2 * slope_mid_a, 2 * slope_mid_b, slope_end])
    return weighted_sum * (h / 6)
576
577
def odeint_fixed(func, y0, t, method='rk4', name=None):
  """ODE integration on a fixed grid (with no step size control).

  Useful in certain scenarios to avoid the overhead of adaptive step size
  control, e.g. when differentiation of the integration result is desired and/or
  the time grid is known a priori to be sufficient.

  Args:
    func: Function that maps a Tensor holding the state `y` and a scalar Tensor
      `t` into a Tensor of state derivatives with respect to time.
    y0: N-D Tensor giving starting value of `y` at time point `t[0]`.
    t: 1-D Tensor holding a sequence of time points for which to solve for
      `y`. The initial time point should be the first element of this sequence,
      and each time must be larger than the previous time. May have any floating
      point dtype.
    method: One of 'midpoint' or 'rk4'.
    name: Optional name for the resulting operation.

  Returns:
    y: (N+1)-D tensor, where the first dimension corresponds to different
      time points. Contains the solved value of y for each desired time point in
      `t`, with the initial value `y0` being the first element along the first
      dimension.

  Raises:
    ValueError: if `method` is not one of 'midpoint' or 'rk4'.
    TypeError: if `t` or `y0` has an invalid dtype.
  """
  # Validate before converting inputs or opening any name scope. Otherwise an
  # unsupported method would first add conversion and Assert ops to the graph,
  # and its arbitrary string would be handed to `ops.name_scope` before we get
  # to report it.
  if method not in ('midpoint', 'rk4'):
    raise ValueError('method not supported: {!s}'.format(method))

  with ops.name_scope(name, 'odeint_fixed', [y0, t]):
    t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
    y0 = ops.convert_to_tensor(y0, name='y0')
    _check_input_types(t, y0)

    with _assert_increasing(t):
      with ops.name_scope(method):
        if method == 'midpoint':
          integrator = _MidpointFixedGridIntegrator()
        else:
          integrator = _RK4FixedGridIntegrator()
        return integrator.integrate(func, y0, t)
618