1// -*- coding: utf-8
2// vim: set fileencoding=utf-8
3
4// This file is part of Eigen, a lightweight C++ template library
5// for linear algebra.
6//
7// Copyright (C) 2009 Thomas Capricelli <orzel@freehackers.org>
8//
9// This Source Code Form is subject to the terms of the Mozilla
10// Public License v. 2.0. If a copy of the MPL was not distributed
11// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
12
13#ifndef EIGEN_NUMERICAL_DIFF_H
14#define EIGEN_NUMERICAL_DIFF_H
15
16namespace Eigen {
17
/**
  * Finite-difference scheme used by NumericalDiff::df().
  *
  *  - Forward : J(:,j) ~= (f(x + h e_j) - f(x)) / h
  *  - Central : J(:,j) ~= (f(x + h e_j) - f(x - h e_j)) / (2 h)
  */
enum NumericalDiffMode {
    Forward = 0,  // one extra evaluation per input coordinate (plus one for f(x))
    Central = 1   // two evaluations per input coordinate
};
22
23
24/**
25  * This class allows you to add a method df() to your functor, which will
26  * use numerical differentiation to compute an approximate of the
27  * derivative for the functor. Of course, if you have an analytical form
28  * for the derivative, you should rather implement df() by yourself.
29  *
30  * More information on
31  * http://en.wikipedia.org/wiki/Numerical_differentiation
32  *
33  * Currently only "Forward" and "Central" scheme are implemented.
34  */
35template<typename _Functor, NumericalDiffMode mode=Forward>
36class NumericalDiff : public _Functor
37{
38public:
39    typedef _Functor Functor;
40    typedef typename Functor::Scalar Scalar;
41    typedef typename Functor::InputType InputType;
42    typedef typename Functor::ValueType ValueType;
43    typedef typename Functor::JacobianType JacobianType;
44
45    NumericalDiff(Scalar _epsfcn=0.) : Functor(), epsfcn(_epsfcn) {}
46    NumericalDiff(const Functor& f, Scalar _epsfcn=0.) : Functor(f), epsfcn(_epsfcn) {}
47
48    // forward constructors
49    template<typename T0>
50        NumericalDiff(const T0& a0) : Functor(a0), epsfcn(0) {}
51    template<typename T0, typename T1>
52        NumericalDiff(const T0& a0, const T1& a1) : Functor(a0, a1), epsfcn(0) {}
53    template<typename T0, typename T1, typename T2>
54        NumericalDiff(const T0& a0, const T1& a1, const T2& a2) : Functor(a0, a1, a2), epsfcn(0) {}
55
56    enum {
57        InputsAtCompileTime = Functor::InputsAtCompileTime,
58        ValuesAtCompileTime = Functor::ValuesAtCompileTime
59    };
60
61    /**
62      * return the number of evaluation of functor
63     */
64    int df(const InputType& _x, JacobianType &jac) const
65    {
66        using std::sqrt;
67        using std::abs;
68        /* Local variables */
69        Scalar h;
70        int nfev=0;
71        const typename InputType::Index n = _x.size();
72        const Scalar eps = sqrt(((std::max)(epsfcn,NumTraits<Scalar>::epsilon() )));
73        ValueType val1, val2;
74        InputType x = _x;
75        // TODO : we should do this only if the size is not already known
76        val1.resize(Functor::values());
77        val2.resize(Functor::values());
78
79        // initialization
80        switch(mode) {
81            case Forward:
82                // compute f(x)
83                Functor::operator()(x, val1); nfev++;
84                break;
85            case Central:
86                // do nothing
87                break;
88            default:
89                eigen_assert(false);
90        };
91
92        // Function Body
93        for (int j = 0; j < n; ++j) {
94            h = eps * abs(x[j]);
95            if (h == 0.) {
96                h = eps;
97            }
98            switch(mode) {
99                case Forward:
100                    x[j] += h;
101                    Functor::operator()(x, val2);
102                    nfev++;
103                    x[j] = _x[j];
104                    jac.col(j) = (val2-val1)/h;
105                    break;
106                case Central:
107                    x[j] += h;
108                    Functor::operator()(x, val2); nfev++;
109                    x[j] -= 2*h;
110                    Functor::operator()(x, val1); nfev++;
111                    x[j] = _x[j];
112                    jac.col(j) = (val2-val1)/(2*h);
113                    break;
114                default:
115                    eigen_assert(false);
116            };
117        }
118        return nfev;
119    }
120private:
121    Scalar epsfcn;
122
123    NumericalDiff& operator=(const NumericalDiff&);
124};
125
126} // end namespace Eigen
127
128//vim: ai ts=4 sts=4 et sw=4
129#endif // EIGEN_NUMERICAL_DIFF_H
130
131