1793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler/*M///////////////////////////////////////////////////////////////////////////////////////
2793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
3793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
5793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//  By downloading, copying, installing or using the software you agree to this license.
6793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//  If you do not agree to this license, do not download, install,
7793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//  copy or use the software.
8793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
9793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
10793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//                        Intel License Agreement
11793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
12793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Copyright (C) 2000, Intel Corporation, all rights reserved.
13793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Third party copyrights are property of their respective owners.
14793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
15793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Redistribution and use in source and binary forms, with or without modification,
16793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// are permitted provided that the following conditions are met:
17793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
18793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//   * Redistribution's of source code must retain the above copyright notice,
19793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//     this list of conditions and the following disclaimer.
20793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
21793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//   * Redistribution's in binary form must reproduce the above copyright notice,
22793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//     this list of conditions and the following disclaimer in the documentation
23793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//     and/or other materials provided with the distribution.
24793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
25793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//   * The name of Intel Corporation may not be used to endorse or promote products
26793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//     derived from this software without specific prior written permission.
27793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
28793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// This software is provided by the copyright holders and contributors "as is" and
29793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// any express or implied warranties, including, but not limited to, the implied
30793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// warranties of merchantability and fitness for a particular purpose are disclaimed.
31793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// In no event shall the Intel Corporation or contributors be liable for any direct,
32793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// indirect, incidental, special, exemplary, or consequential damages
33793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// (including, but not limited to, procurement of substitute goods or services;
34793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// loss of use, data, or profits; or business interruption) however caused
35793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// and on any theory of liability, whether in contract, strict liability,
36793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// or tort (including negligence or otherwise) arising in any way out of
37793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// the use of this software, even if advised of the possibility of such damage.
38793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
39793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//M*/
40793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
41793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler#include "precomp.hpp"
42793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
43793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslernamespace cv { namespace ml {
44793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
45793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerstruct AnnParams
46793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
47793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    AnnParams()
48793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
49793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        termCrit = TermCriteria( TermCriteria::COUNT + TermCriteria::EPS, 1000, 0.01 );
50793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        trainMethod = ANN_MLP::RPROP;
51793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        bpDWScale = bpMomentScale = 0.1;
52793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        rpDW0 = 0.1; rpDWPlus = 1.2; rpDWMinus = 0.5;
53793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        rpDWMin = FLT_EPSILON; rpDWMax = 50.;
54793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
55793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
56793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    TermCriteria termCrit;
57793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    int trainMethod;
58793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
59793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    double bpDWScale;
60793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    double bpMomentScale;
61793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
62793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    double rpDW0;
63793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    double rpDWPlus;
64793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    double rpDWMinus;
65793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    double rpDWMin;
66793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    double rpDWMax;
67793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler};
68793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
/** Clamps val: first raised to min_val if below it, then lowered to max_val if
 *  above it (so max_val wins when min_val > max_val). Values that compare false
 *  against both bounds, such as NaN, are returned unchanged. */
template <typename T>
inline T inBounds(T val, T min_val, T max_val)
{
    const T raised = (val < min_val) ? min_val : val;
    return (max_val < raised) ? max_val : raised;
}
74793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
75793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerclass ANN_MLPImpl : public ANN_MLP
76793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
77793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerpublic:
    /** Creates an untrained network with the default configuration:
     *  SIGMOID_SYM activation (zero parameters, so setActivationFunction()
     *  substitutes that function's defaults), an empty topology, and RPROP
     *  training with dw0 = 0.1 and dw_min = FLT_EPSILON.
     *  Call order matters: setLayerSizes() itself calls clear(), which resets
     *  the trained state, so the setters run after the initial clear().
     */
    ANN_MLPImpl()
    {
        clear();
        setActivationFunction( SIGMOID_SYM, 0, 0 );
        setLayerSizes(Mat());
        setTrainMethod(ANN_MLP::RPROP, 0.1, FLT_EPSILON);
    }
85793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
86793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual ~ANN_MLPImpl() {}
87793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
    // Accessors for the training parameters stored in `params`.
    // CV_IMPL_PROPERTY is defined outside this file; presumably each line
    // expands to a get<Name>()/set<Name>() pair bound to the given field
    // (NOTE(review): confirm against the macro definition).
    CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.termCrit)
    CV_IMPL_PROPERTY(double, BackpropWeightScale, params.bpDWScale)
    CV_IMPL_PROPERTY(double, BackpropMomentumScale, params.bpMomentScale)
    CV_IMPL_PROPERTY(double, RpropDW0, params.rpDW0)
    CV_IMPL_PROPERTY(double, RpropDWPlus, params.rpDWPlus)
    CV_IMPL_PROPERTY(double, RpropDWMinus, params.rpDWMinus)
    CV_IMPL_PROPERTY(double, RpropDWMin, params.rpDWMin)
    CV_IMPL_PROPERTY(double, RpropDWMax, params.rpDWMax)
96793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
97793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void clear()
98793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
99793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        min_val = max_val = min_val1 = max_val1 = 0.;
100793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        rng = RNG((uint64)-1);
101793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        weights.clear();
102793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        trained = false;
103793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        max_buf_sz = 1 << 12;
104793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
105793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
106793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    int layer_count() const { return (int)layer_sizes.size(); }
107793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
108793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void setTrainMethod(int method, double param1, double param2)
109793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
110793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if (method != ANN_MLP::RPROP && method != ANN_MLP::BACKPROP)
111793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            method = ANN_MLP::RPROP;
112793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.trainMethod = method;
113793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if(method == ANN_MLP::RPROP )
114793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
115793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( param1 < FLT_EPSILON )
116793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                param1 = 1.;
117793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            params.rpDW0 = param1;
118793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            params.rpDWMin = std::max( param2, 0. );
119793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
120793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else if(method == ANN_MLP::BACKPROP )
121793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
122793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( param1 <= 0 )
123793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                param1 = 0.1;
124793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            params.bpDWScale = inBounds<double>(param1, 1e-3, 1.);
125793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( param2 < 0 )
126793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                param2 = 0.1;
127793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            params.bpMomentScale = std::min( param2, 1. );
128793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
129793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
130793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
131793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    int getTrainMethod() const
132793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
133793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return params.trainMethod;
134793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
135793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
136793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void setActivationFunction(int _activ_func, double _f_param1, double _f_param2 )
137793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
138793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( _activ_func < 0 || _activ_func > GAUSSIAN )
139793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            CV_Error( CV_StsOutOfRange, "Unknown activation function" );
140793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
141793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        activ_func = _activ_func;
142793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
143793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        switch( activ_func )
144793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
145793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        case SIGMOID_SYM:
146793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            max_val = 0.95; min_val = -max_val;
147793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            max_val1 = 0.98; min_val1 = -max_val1;
148793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( fabs(_f_param1) < FLT_EPSILON )
149793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                _f_param1 = 2./3;
150793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( fabs(_f_param2) < FLT_EPSILON )
151793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                _f_param2 = 1.7159;
152793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            break;
153793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        case GAUSSIAN:
154793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            max_val = 1.; min_val = 0.05;
155793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            max_val1 = 1.; min_val1 = 0.02;
156793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( fabs(_f_param1) < FLT_EPSILON )
157793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                _f_param1 = 1.;
158793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( fabs(_f_param2) < FLT_EPSILON )
159793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                _f_param2 = 1.;
160793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            break;
161793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        default:
162793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            min_val = max_val = min_val1 = max_val1 = 0.;
163793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            _f_param1 = 1.;
164793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            _f_param2 = 0.;
165793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
166793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
167793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        f_param1 = _f_param1;
168793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        f_param2 = _f_param2;
169793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
170793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
171793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
    /** Randomly initialises the weights of every layer i >= 1.
     *
     * weights[i] is (layer_sizes[i-1]+1) x layer_sizes[i] (see setLayerSizes),
     * stored row-major as w[k*n2 + j]; row n1 is the bias row used by
     * calc_activ_func. Every entry is first drawn uniformly from [-1, 1) using
     * the member rng, in column order (neuron j), then row order (input k).
     * For hidden layers each column is then rescaled; the output layer keeps
     * the raw uniform values.
     */
    void init_weights()
    {
        int i, j, k, l_count = layer_count();

        for( i = 1; i < l_count; i++ )
        {
            int n1 = layer_sizes[i-1];
            int n2 = layer_sizes[i];
            // NOTE(review): textbook Nguyen-Widrow uses 0.7*H^(1/N) with H = hidden
            // neurons and N = inputs; this computes 0.7*n1^(1/(n2-1)) instead.
            // Kept as-is since changing it alters every trained model.
            double val = 0, G = n2 > 2 ? 0.7*pow((double)n1,1./(n2-1)) : 1.;
            double* w = weights[i].ptr<double>();

            // initialize weights using Nguyen-Widrow algorithm
            for( j = 0; j < n2; j++ )
            {
                double s = 0;
                for( k = 0; k <= n1; k++ )
                {
                    val = rng.uniform(0., 1.)*2-1.;
                    w[k*n2 + j] = val;
                    s += fabs(val);
                }

                if( i < l_count - 1 )
                {
                    // After the loop, val is the bias weight (k == n1), so this is
                    // the reciprocal L1 norm of the non-bias weights of neuron j.
                    s = 1./(s - fabs(val));
                    for( k = 0; k <= n1; k++ )
                        w[k*n2 + j] *= s;
                    // Spread the biases linearly over [-G, G) by neuron index.
                    w[n1*n2 + j] *= G*(-1+j*2./n2);
                }
            }
        }
    }
204793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
205793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    Mat getLayerSizes() const
206793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
207793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return Mat_<int>(layer_sizes, true);
208793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
209793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
210793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void setLayerSizes( InputArray _layer_sizes )
211793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
212793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        clear();
213793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
214793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        _layer_sizes.copyTo(layer_sizes);
215793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int l_count = layer_count();
216793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
217793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        weights.resize(l_count + 2);
218793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        max_lsize = 0;
219793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
220793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( l_count > 0 )
221793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
222793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( int i = 0; i < l_count; i++ )
223793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
224793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                int n = layer_sizes[i];
225793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( n < 1 + (0 < i && i < l_count-1))
226793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    CV_Error( CV_StsOutOfRange,
227793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                             "there should be at least one input and one output "
228793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                             "and every hidden layer must have more than 1 neuron" );
229793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                max_lsize = std::max( max_lsize, n );
230793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( i > 0 )
231793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    weights[i].create(layer_sizes[i-1]+1, n, CV_64F);
232793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
233793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
234793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int ninputs = layer_sizes.front();
235793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int noutputs = layer_sizes.back();
236793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            weights[0].create(1, ninputs*2, CV_64F);
237793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            weights[l_count].create(1, noutputs*2, CV_64F);
238793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            weights[l_count+1].create(1, noutputs*2, CV_64F);
239793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
240793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
241793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
    /** Runs the forward pass on each row of _inputs.
     *
     * @param _inputs  one sample per row, CV_32F or CV_64F, with layer_sizes[0]
     *                 columns (asserted below).
     * @param _outputs if requested, receives an n x noutputs matrix of the same
     *                 type as _inputs holding the output-scaled responses. If it
     *                 is not requested, exactly one sample is allowed and the
     *                 responses go to a scratch area at the end of the buffer.
     * The trailing int (flags) argument is ignored.
     * @return for a single sample, the column index of the largest response
     *         (minMaxIdx reports (row, col), and row is 0); otherwise 0.
     * Raises CV_StsError if the network has not been trained or loaded.
     */
    float predict( InputArray _inputs, OutputArray _outputs, int ) const
    {
        if( !trained )
            CV_Error( CV_StsError, "The network has not been trained or loaded" );

        Mat inputs = _inputs.getMat();
        int type = inputs.type(), l_count = layer_count();
        int n = inputs.rows, dn0 = n;

        CV_Assert( (type == CV_32F || type == CV_64F) && inputs.cols == layer_sizes[0] );
        int noutputs = layer_sizes[l_count-1];
        Mat outputs;

        // Each sample needs two ping-pong regions of max_lsize doubles. If all
        // n samples would exceed max_buf_sz, process them in chunks of dn0 rows
        // (at least 1).
        int min_buf_sz = 2*max_lsize;
        int buf_sz = n*min_buf_sz;

        if( buf_sz > max_buf_sz )
        {
            dn0 = max_buf_sz/min_buf_sz;
            dn0 = std::max( dn0, 1 );
            buf_sz = dn0*min_buf_sz;
        }

        // Layout: [0, max_lsize*dn0) and [max_lsize*dn0, buf_sz) alternate as
        // layer input/output; the extra noutputs doubles at buf + buf_sz hold
        // the result when the caller did not ask for _outputs.
        cv::AutoBuffer<double> _buf(buf_sz+noutputs);
        double* buf = _buf;

        if( !_outputs.needed() )
        {
            CV_Assert( n == 1 );
            outputs = Mat(n, noutputs, type, buf + buf_sz);
        }
        else
        {
            _outputs.create(n, noutputs, type);
            outputs = _outputs.getMat();
        }

        int dn = 0;
        for( int i = 0; i < n; i += dn )
        {
            dn = std::min( dn0, n - i );

            // Scaled inputs go to the first region (offset 0).
            Mat layer_in = inputs.rowRange(i, i + dn);
            Mat layer_out( dn, layer_in.cols, CV_64F, buf);

            scale_input( layer_in, layer_out );
            layer_in = layer_out;

            for( int j = 1; j < l_count; j++ )
            {
                // Odd layers write to the second region, even layers to the
                // first, so a layer never overwrites its own input.
                double* data = buf + ((j&1) ? max_lsize*dn0 : 0);
                int cols = layer_sizes[j];

                layer_out = Mat(dn, cols, CV_64F, data);
                // Bias row (the last row of weights[j]) is excluded from the
                // product; calc_activ_func reads it separately.
                Mat w = weights[j].rowRange(0, layer_in.cols);
                gemm(layer_in, w, 1, noArray(), 0, layer_out);
                calc_activ_func( layer_out, weights[j] );

                layer_in = layer_out;
            }

            layer_out = outputs.rowRange(i, i + dn);
            scale_output( layer_in, layer_out );
        }

        if( n == 1 )
        {
            int maxIdx[] = {0, 0};
            minMaxIdx(outputs, 0, 0, 0, maxIdx);
            return (float)(maxIdx[0] + maxIdx[1]);
        }

        return 0.f;
    }
316793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
317793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void scale_input( const Mat& _src, Mat& _dst ) const
318793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
319793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int cols = _src.cols;
320793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const double* w = weights[0].ptr<double>();
321793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
322793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( _src.type() == CV_32F )
323793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
324793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( int i = 0; i < _src.rows; i++ )
325793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
326793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                const float* src = _src.ptr<float>(i);
327793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* dst = _dst.ptr<double>(i);
328793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( int j = 0; j < cols; j++ )
329793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    dst[j] = src[j]*w[j*2] + w[j*2+1];
330793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
331793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
332793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else
333793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
334793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( int i = 0; i < _src.rows; i++ )
335793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
336793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                const float* src = _src.ptr<float>(i);
337793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* dst = _dst.ptr<double>(i);
338793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( int j = 0; j < cols; j++ )
339793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    dst[j] = src[j]*w[j*2] + w[j*2+1];
340793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
341793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
342793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
343793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
344793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void scale_output( const Mat& _src, Mat& _dst ) const
345793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
346793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int cols = _src.cols;
347793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const double* w = weights[layer_count()].ptr<double>();
348793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
349793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( _dst.type() == CV_32F )
350793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
351793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( int i = 0; i < _src.rows; i++ )
352793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
353793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                const double* src = _src.ptr<double>(i);
354793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                float* dst = _dst.ptr<float>(i);
355793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( int j = 0; j < cols; j++ )
356793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    dst[j] = (float)(src[j]*w[j*2] + w[j*2+1]);
357793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
358793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
359793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else
360793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
361793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( int i = 0; i < _src.rows; i++ )
362793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
363793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                const double* src = _src.ptr<double>(i);
364793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* dst = _dst.ptr<double>(i);
365793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( int j = 0; j < cols; j++ )
366793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    dst[j] = src[j]*w[j*2] + w[j*2+1];
367793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
368793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
369793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
370793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
371793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void calc_activ_func( Mat& sums, const Mat& w ) const
372793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
373793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const double* bias = w.ptr<double>(w.rows-1);
374793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, j, n = sums.rows, cols = sums.cols;
375793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double scale = 0, scale2 = f_param2;
376793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
377793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        switch( activ_func )
378793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
379793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            case IDENTITY:
380793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                scale = 1.;
381793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                break;
382793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            case SIGMOID_SYM:
383793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                scale = -f_param1;
384793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                break;
385793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            case GAUSSIAN:
386793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                scale = -f_param1*f_param1;
387793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                break;
388793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            default:
389793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                ;
390793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
391793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
392793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        CV_Assert( sums.isContinuous() );
393793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
394793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( activ_func != GAUSSIAN )
395793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
396793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
397793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
398793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* data = sums.ptr<double>(i);
399793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( j = 0; j < cols; j++ )
400793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    data[j] = (data[j] + bias[j])*scale;
401793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
402793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
403793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( activ_func == IDENTITY )
404793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                return;
405793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
406793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else
407793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
408793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
409793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
410793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* data = sums.ptr<double>(i);
411793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( j = 0; j < cols; j++ )
412793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
413793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double t = data[j] + bias[j];
414793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    data[j] = t*t*scale;
415793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
416793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
417793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
418793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
419793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        exp( sums, sums );
420793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
421793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( sums.isContinuous() )
422793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
423793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            cols *= n;
424793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            n = 1;
425793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
426793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
427793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        switch( activ_func )
428793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
429793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            case SIGMOID_SYM:
430793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( i = 0; i < n; i++ )
431793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
432793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double* data = sums.ptr<double>(i);
433793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    for( j = 0; j < cols; j++ )
434793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    {
435793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        double t = scale2*(1. - data[j])/(1. + data[j]);
436793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        data[j] = t;
437793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    }
438793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
439793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                break;
440793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
441793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            case GAUSSIAN:
442793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( i = 0; i < n; i++ )
443793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
444793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double* data = sums.ptr<double>(i);
445793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    for( j = 0; j < cols; j++ )
446793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        data[j] = scale2*data[j];
447793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
448793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                break;
449793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
450793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            default:
451793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                ;
452793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
453793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
454793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
455793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void calc_activ_func_deriv( Mat& _xf, Mat& _df, const Mat& w ) const
456793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
457793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const double* bias = w.ptr<double>(w.rows-1);
458793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, j, n = _xf.rows, cols = _xf.cols;
459793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
460793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( activ_func == IDENTITY )
461793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
462793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
463793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
464793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* xf = _xf.ptr<double>(i);
465793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* df = _df.ptr<double>(i);
466793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
467793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( j = 0; j < cols; j++ )
468793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
469793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    xf[j] += bias[j];
470793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    df[j] = 1;
471793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
472793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
473793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
474793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else if( activ_func == GAUSSIAN )
475793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
476793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double scale = -f_param1*f_param1;
477793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double scale2 = scale*f_param2;
478793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
479793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
480793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* xf = _xf.ptr<double>(i);
481793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* df = _df.ptr<double>(i);
482793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
483793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( j = 0; j < cols; j++ )
484793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
485793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double t = xf[j] + bias[j];
486793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    df[j] = t*2*scale2;
487793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    xf[j] = t*t*scale;
488793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
489793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
490793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            exp( _xf, _xf );
491793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
492793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
493793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
494793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* xf = _xf.ptr<double>(i);
495793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* df = _df.ptr<double>(i);
496793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
497793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( j = 0; j < cols; j++ )
498793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    df[j] *= xf[j];
499793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
500793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
501793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else
502793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
503793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double scale = f_param1;
504793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double scale2 = f_param2;
505793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
506793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
507793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
508793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* xf = _xf.ptr<double>(i);
509793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* df = _df.ptr<double>(i);
510793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
511793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( j = 0; j < cols; j++ )
512793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
513793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    xf[j] = (xf[j] + bias[j])*scale;
514793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    df[j] = -fabs(xf[j]);
515793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
516793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
517793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
518793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            exp( _df, _df );
519793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
520793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            // ((1+exp(-ax))^-1)'=a*((1+exp(-ax))^-2)*exp(-ax);
521793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            // ((1-exp(-ax))/(1+exp(-ax)))'=(a*exp(-ax)*(1+exp(-ax)) + a*exp(-ax)*(1-exp(-ax)))/(1+exp(-ax))^2=
522793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            // 2*a*exp(-ax)/(1+exp(-ax))^2
523793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            scale *= 2*f_param2;
524793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
525793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
526793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* xf = _xf.ptr<double>(i);
527793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double* df = _df.ptr<double>(i);
528793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
529793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( j = 0; j < cols; j++ )
530793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
531793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    int s0 = xf[j] > 0 ? 1 : -1;
532793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double t0 = 1./(1. + df[j]);
533793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double t1 = scale*df[j]*t0*t0;
534793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    t0 *= scale2*(1. - df[j])*s0;
535793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    df[j] = t1;
536793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    xf[j] = t0;
537793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
538793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
539793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
540793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
541793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
542793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void calc_input_scale( const Mat& inputs, int flags )
543793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
544793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        bool reset_weights = (flags & UPDATE_WEIGHTS) == 0;
545793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        bool no_scale = (flags & NO_INPUT_SCALE) != 0;
546793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double* scale = weights[0].ptr<double>();
547793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int count = inputs.rows;
548793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
549793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( reset_weights )
550793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
551793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int i, j, vcount = layer_sizes[0];
552793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int type = inputs.type();
553793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double a = no_scale ? 1. : 0.;
554793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
555793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( j = 0; j < vcount; j++ )
556793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                scale[2*j] = a, scale[j*2+1] = 0.;
557793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
558793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( no_scale )
559793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                return;
560793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
561793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < count; i++ )
562793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
563793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                const uchar* p = inputs.ptr(i);
564793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                const float* f = (const float*)p;
565793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                const double* d = (const double*)p;
566793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( j = 0; j < vcount; j++ )
567793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
568793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double t = type == CV_32F ? (double)f[j] : d[j];
569793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    scale[j*2] += t;
570793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    scale[j*2+1] += t*t;
571793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
572793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
573793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
574793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( j = 0; j < vcount; j++ )
575793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
576793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double s = scale[j*2], s2 = scale[j*2+1];
577793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double m = s/count, sigma2 = s2/count - m*m;
578793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                scale[j*2] = sigma2 < DBL_EPSILON ? 1 : 1./sqrt(sigma2);
579793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                scale[j*2+1] = -m*scale[j*2];
580793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
581793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
582793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
583793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
584793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void calc_output_scale( const Mat& outputs, int flags )
585793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
586793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, j, vcount = layer_sizes.back();
587793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int type = outputs.type();
588793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double m = min_val, M = max_val, m1 = min_val1, M1 = max_val1;
589793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        bool reset_weights = (flags & UPDATE_WEIGHTS) == 0;
590793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        bool no_scale = (flags & NO_OUTPUT_SCALE) != 0;
591793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int l_count = layer_count();
592793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double* scale = weights[l_count].ptr<double>();
593793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double* inv_scale = weights[l_count+1].ptr<double>();
594793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int count = outputs.rows;
595793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
596793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( reset_weights )
597793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
598793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double a0 = no_scale ? 1 : DBL_MAX, b0 = no_scale ? 0 : -DBL_MAX;
599793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
600793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( j = 0; j < vcount; j++ )
601793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
602793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                scale[2*j] = inv_scale[2*j] = a0;
603793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                scale[j*2+1] = inv_scale[2*j+1] = b0;
604793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
605793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
606793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( no_scale )
607793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                return;
608793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
609793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
610793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( i = 0; i < count; i++ )
611793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
612793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            const uchar* p = outputs.ptr(i);
613793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            const float* f = (const float*)p;
614793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            const double* d = (const double*)p;
615793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
616793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( j = 0; j < vcount; j++ )
617793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
618793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double t = type == CV_32F ? (double)f[j] : d[j];
619793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
620793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( reset_weights )
621793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
622793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double mj = scale[j*2], Mj = scale[j*2+1];
623793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    if( mj > t ) mj = t;
624793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    if( Mj < t ) Mj = t;
625793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
626793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    scale[j*2] = mj;
627793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    scale[j*2+1] = Mj;
628793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
629793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                else if( !no_scale )
630793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
631793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    t = t*inv_scale[j*2] + inv_scale[2*j+1];
632793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    if( t < m1 || t > M1 )
633793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        CV_Error( CV_StsOutOfRange,
634793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                 "Some of new output training vector components run exceed the original range too much" );
635793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
636793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
637793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
638793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
639793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( reset_weights )
640793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( j = 0; j < vcount; j++ )
641793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
642793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                // map mj..Mj to m..M
643793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double mj = scale[j*2], Mj = scale[j*2+1];
644793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double a, b;
645793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double delta = Mj - mj;
646793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( delta < DBL_EPSILON )
647793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    a = 1, b = (M + m - Mj - mj)*0.5;
648793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                else
649793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    a = (M - m)/delta, b = m - mj*a;
650793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                inv_scale[j*2] = a; inv_scale[j*2+1] = b;
651793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                a = 1./a; b = -b*a;
652793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                scale[j*2] = a; scale[j*2+1] = b;
653793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
654793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
655793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
656793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void prepare_to_train( const Mat& inputs, const Mat& outputs,
657793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                           Mat& sample_weights, int flags )
658793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
659793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( layer_sizes.empty() )
660793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            CV_Error( CV_StsError,
661793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                     "The network has not been created. Use method create or the appropriate constructor" );
662793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
663793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( (inputs.type() != CV_32F && inputs.type() != CV_64F) ||
664793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            inputs.cols != layer_sizes[0] )
665793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            CV_Error( CV_StsBadArg,
666793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                     "input training data should be a floating-point matrix with "
667793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                     "the number of rows equal to the number of training samples and "
668793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                     "the number of columns equal to the size of 0-th (input) layer" );
669793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
670793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( (outputs.type() != CV_32F && outputs.type() != CV_64F) ||
671793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            outputs.cols != layer_sizes.back() )
672793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            CV_Error( CV_StsBadArg,
673793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                     "output training data should be a floating-point matrix with "
674793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                     "the number of rows equal to the number of training samples and "
675793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                     "the number of columns equal to the size of last (output) layer" );
676793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
677793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( inputs.rows != outputs.rows )
678793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            CV_Error( CV_StsUnmatchedSizes, "The numbers of input and output samples do not match" );
679793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
680793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        Mat temp;
681793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double s = sum(sample_weights)[0];
682793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        sample_weights.convertTo(temp, CV_64F, 1./s);
683793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        sample_weights = temp;
684793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
685793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        calc_input_scale( inputs, flags );
686793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        calc_output_scale( outputs, flags );
687793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
688793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
689793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    bool train( const Ptr<TrainData>& trainData, int flags )
690793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
691793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const int MAX_ITER = 1000;
692793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const double DEFAULT_EPSILON = FLT_EPSILON;
693793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
694793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        // initialize training data
695793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        Mat inputs = trainData->getTrainSamples();
696793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        Mat outputs = trainData->getTrainResponses();
697793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        Mat sw = trainData->getTrainSampleWeights();
698793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        prepare_to_train( inputs, outputs, sw, flags );
699793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
700793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        // ... and link weights
701793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( !(flags & UPDATE_WEIGHTS) )
702793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            init_weights();
703793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
704793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        TermCriteria termcrit;
705793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        termcrit.type = TermCriteria::COUNT + TermCriteria::EPS;
706793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        termcrit.maxCount = std::max((params.termCrit.type & CV_TERMCRIT_ITER ? params.termCrit.maxCount : MAX_ITER), 1);
707793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        termcrit.epsilon = std::max((params.termCrit.type & CV_TERMCRIT_EPS ? params.termCrit.epsilon : DEFAULT_EPSILON), DBL_EPSILON);
708793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
709793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int iter = params.trainMethod == ANN_MLP::BACKPROP ?
710793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            train_backprop( inputs, outputs, sw, termcrit ) :
711793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            train_rprop( inputs, outputs, sw, termcrit );
712793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
713793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        trained = iter > 0;
714793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return trained;
715793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
716793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
717793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    int train_backprop( const Mat& inputs, const Mat& outputs, const Mat& _sw, TermCriteria termCrit )
718793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
719793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, j, k;
720793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double prev_E = DBL_MAX*0.5, E = 0;
721793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int itype = inputs.type(), otype = outputs.type();
722793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
723793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int count = inputs.rows;
724793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
725793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int iter = -1, max_iter = termCrit.maxCount*count;
726793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double epsilon = termCrit.epsilon*count;
727793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
728793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int l_count = layer_count();
729793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int ivcount = layer_sizes[0];
730793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int ovcount = layer_sizes.back();
731793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
732793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        // allocate buffers
733793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<vector<double> > x(l_count);
734793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<vector<double> > df(l_count);
735793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<Mat> dw(l_count);
736793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
737793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( i = 0; i < l_count; i++ )
738793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
739793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int n = layer_sizes[i];
740793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            x[i].resize(n+1);
741793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            df[i].resize(n);
742793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            dw[i] = Mat::zeros(weights[i].size(), CV_64F);
743793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
744793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
745793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        Mat _idx_m(1, count, CV_32S);
746793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int* _idx = _idx_m.ptr<int>();
747793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( i = 0; i < count; i++ )
748793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            _idx[i] = i;
749793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
750793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        AutoBuffer<double> _buf(max_lsize*2);
751793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double* buf[] = { _buf, (double*)_buf + max_lsize };
752793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
753793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const double* sw = _sw.empty() ? 0 : _sw.ptr<double>();
754793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
755793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        // run back-propagation loop
756793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        /*
757793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         y_i = w_i*x_{i-1}
758793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         x_i = f(y_i)
759793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         E = 1/2*||u - x_N||^2
760793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         grad_N = (x_N - u)*f'(y_i)
761793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         dw_i(t) = momentum*dw_i(t-1) + dw_scale*x_{i-1}*grad_i
762793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         w_i(t+1) = w_i(t) + dw_i(t)
763793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         grad_{i-1} = w_i^t*grad_i
764793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        */
765793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( iter = 0; iter < max_iter; iter++ )
766793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
767793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int idx = iter % count;
768793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double sweight = sw ? count*sw[idx] : 1.;
769793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
770793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( idx == 0 )
771793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
772793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                //printf("%d. E = %g\n", iter/count, E);
773793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( fabs(prev_E - E) < epsilon )
774793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    break;
775793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                prev_E = E;
776793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                E = 0;
777793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
778793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                // shuffle indices
779793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( i = 0; i < count; i++ )
780793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
781793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    j = rng.uniform(0, count);
782793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    k = rng.uniform(0, count);
783793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    std::swap(_idx[j], _idx[k]);
784793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
785793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
786793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
787793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            idx = _idx[idx];
788793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
789793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            const uchar* x0data_p = inputs.ptr(idx);
790793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            const float* x0data_f = (const float*)x0data_p;
791793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            const double* x0data_d = (const double*)x0data_p;
792793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
793793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double* w = weights[0].ptr<double>();
794793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( j = 0; j < ivcount; j++ )
795793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                x[0][j] = (itype == CV_32F ? (double)x0data_f[j] : x0data_d[j])*w[j*2] + w[j*2 + 1];
796793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
797793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            Mat x1( 1, ivcount, CV_64F, &x[0][0] );
798793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
799793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            // forward pass, compute y[i]=w*x[i-1], x[i]=f(y[i]), df[i]=f'(y[i])
800793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 1; i < l_count; i++ )
801793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
802793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                int n = layer_sizes[i];
803793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                Mat x2(1, n, CV_64F, &x[i][0] );
804793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                Mat _w = weights[i].rowRange(0, x1.cols);
805793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                gemm(x1, _w, 1, noArray(), 0, x2);
806793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                Mat _df(1, n, CV_64F, &df[i][0] );
807793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                calc_activ_func_deriv( x2, _df, weights[i] );
808793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                x1 = x2;
809793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
810793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
811793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            Mat grad1( 1, ovcount, CV_64F, buf[l_count&1] );
812793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            w = weights[l_count+1].ptr<double>();
813793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
814793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            // calculate error
815793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            const uchar* udata_p = outputs.ptr(idx);
816793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            const float* udata_f = (const float*)udata_p;
817793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            const double* udata_d = (const double*)udata_p;
818793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
819793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double* gdata = grad1.ptr<double>();
820793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( k = 0; k < ovcount; k++ )
821793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
822793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double t = (otype == CV_32F ? (double)udata_f[k] : udata_d[k])*w[k*2] + w[k*2+1] - x[l_count-1][k];
823793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                gdata[k] = t*sweight;
824793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                E += t*t;
825793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
826793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            E *= sweight;
827793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
828793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            // backward pass, update weights
829793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = l_count-1; i > 0; i-- )
830793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
831793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                int n1 = layer_sizes[i-1], n2 = layer_sizes[i];
832793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                Mat _df(1, n2, CV_64F, &df[i][0]);
833793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                multiply( grad1, _df, grad1 );
834793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                Mat _x(n1+1, 1, CV_64F, &x[i-1][0]);
835793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                x[i-1][n1] = 1.;
836793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                gemm( _x, grad1, params.bpDWScale, dw[i], params.bpMomentScale, dw[i] );
837793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                add( weights[i], dw[i], weights[i] );
838793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( i > 1 )
839793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
840793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    Mat grad2(1, n1, CV_64F, buf[i&1]);
841793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    Mat _w = weights[i].rowRange(0, n1);
842793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    gemm( grad1, _w, 1, noArray(), 0, grad2, GEMM_2_T );
843793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    grad1 = grad2;
844793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
845793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
846793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
847793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
848793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        iter /= count;
849793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return iter;
850793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
851793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
852793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    struct RPropLoop : public ParallelLoopBody
853793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
854793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        RPropLoop(ANN_MLPImpl* _ann,
855793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                  const Mat& _inputs, const Mat& _outputs, const Mat& _sw,
856793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                  int _dcount0, vector<Mat>& _dEdw, double* _E)
857793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
858793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            ann = _ann;
859793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            inputs = _inputs;
860793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            outputs = _outputs;
861793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            sw = _sw.ptr<double>();
862793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            dcount0 = _dcount0;
863793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            dEdw = &_dEdw;
864793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            pE = _E;
865793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
866793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
867793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        ANN_MLPImpl* ann;
868793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<Mat>* dEdw;
869793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        Mat inputs, outputs;
870793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const double* sw;
871793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int dcount0;
872793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double* pE;
873793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
874793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        void operator()( const Range& range ) const
875793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
876793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double inv_count = 1./inputs.rows;
877793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int ivcount = ann->layer_sizes.front();
878793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int ovcount = ann->layer_sizes.back();
879793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int itype = inputs.type(), otype = outputs.type();
880793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int count = inputs.rows;
881793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int i, j, k, l_count = ann->layer_count();
882793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            vector<vector<double> > x(l_count);
883793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            vector<vector<double> > df(l_count);
884793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            vector<double> _buf(ann->max_lsize*dcount0*2);
885793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double* buf[] = { &_buf[0], &_buf[ann->max_lsize*dcount0] };
886793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double E = 0;
887793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
888793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < l_count; i++ )
889793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
890793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                x[i].resize(ann->layer_sizes[i]*dcount0);
891793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                df[i].resize(ann->layer_sizes[i]*dcount0);
892793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
893793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
894793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( int si = range.start; si < range.end; si++ )
895793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
896793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                int i0 = si*dcount0, i1 = std::min((si + 1)*dcount0, count);
897793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                int dcount = i1 - i0;
898793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                const double* w = ann->weights[0].ptr<double>();
899793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
900793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                // grab and preprocess input data
901793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( i = 0; i < dcount; i++ )
902793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
903793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    const uchar* x0data_p = inputs.ptr(i0 + i);
904793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    const float* x0data_f = (const float*)x0data_p;
905793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    const double* x0data_d = (const double*)x0data_p;
906793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
907793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double* xdata = &x[0][i*ivcount];
908793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    for( j = 0; j < ivcount; j++ )
909793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        xdata[j] = (itype == CV_32F ? (double)x0data_f[j] : x0data_d[j])*w[j*2] + w[j*2+1];
910793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
911793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                Mat x1(dcount, ivcount, CV_64F, &x[0][0]);
912793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
913793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                // forward pass, compute y[i]=w*x[i-1], x[i]=f(y[i]), df[i]=f'(y[i])
914793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( i = 1; i < l_count; i++ )
915793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
916793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    Mat x2( dcount, ann->layer_sizes[i], CV_64F, &x[i][0] );
917793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    Mat _w = ann->weights[i].rowRange(0, x1.cols);
918793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    gemm( x1, _w, 1, noArray(), 0, x2 );
919793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    Mat _df( x2.size(), CV_64F, &df[i][0] );
920793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    ann->calc_activ_func_deriv( x2, _df, ann->weights[i] );
921793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    x1 = x2;
922793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
923793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
924793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                Mat grad1(dcount, ovcount, CV_64F, buf[l_count & 1]);
925793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
926793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                w = ann->weights[l_count+1].ptr<double>();
927793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
928793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                // calculate error
929793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( i = 0; i < dcount; i++ )
930793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
931793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    const uchar* udata_p = outputs.ptr(i0+i);
932793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    const float* udata_f = (const float*)udata_p;
933793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    const double* udata_d = (const double*)udata_p;
934793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
935793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    const double* xdata = &x[l_count-1][i*ovcount];
936793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double* gdata = grad1.ptr<double>(i);
937793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double sweight = sw ? sw[si+i] : inv_count, E1 = 0;
938793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
939793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    for( j = 0; j < ovcount; j++ )
940793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    {
941793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        double t = (otype == CV_32F ? (double)udata_f[j] : udata_d[j])*w[j*2] + w[j*2+1] - xdata[j];
942793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        gdata[j] = t*sweight;
943793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        E1 += t*t;
944793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    }
945793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    E += sweight*E1;
946793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
947793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
948793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( i = l_count-1; i > 0; i-- )
949793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
950793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    int n1 = ann->layer_sizes[i-1], n2 = ann->layer_sizes[i];
951793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    Mat _df(dcount, n2, CV_64F, &df[i][0]);
952793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    multiply(grad1, _df, grad1);
953793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
954793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    {
955793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        AutoLock lock(ann->mtx);
956793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        Mat _dEdw = dEdw->at(i).rowRange(0, n1);
957793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        x1 = Mat(dcount, n1, CV_64F, &x[i-1][0]);
958793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        gemm(x1, grad1, 1, _dEdw, 1, _dEdw, GEMM_1_T);
959793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
960793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        // update bias part of dEdw
961793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        double* dst = dEdw->at(i).ptr<double>(n1);
962793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        for( k = 0; k < dcount; k++ )
963793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        {
964793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            const double* src = grad1.ptr<double>(k);
965793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            for( j = 0; j < n2; j++ )
966793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                dst[j] += src[j];
967793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        }
968793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    }
969793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
970793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    Mat grad2( dcount, n1, CV_64F, buf[i&1] );
971793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    if( i > 1 )
972793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    {
973793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        Mat _w = ann->weights[i].rowRange(0, n1);
974793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        gemm(grad1, _w, 1, noArray(), 0, grad2, GEMM_2_T);
975793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    }
976793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    grad1 = grad2;
977793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
978793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
979793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
980793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                AutoLock lock(ann->mtx);
981793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                *pE += E;
982793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
983793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
984793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    };
985793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
986793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    int train_rprop( const Mat& inputs, const Mat& outputs, const Mat& _sw, TermCriteria termCrit )
987793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
988793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const int max_buf_size = 1 << 16;
989793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, iter = -1, count = inputs.rows;
990793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
991793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double prev_E = DBL_MAX*0.5;
992793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
993793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int max_iter = termCrit.maxCount;
994793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double epsilon = termCrit.epsilon;
995793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double dw_plus = params.rpDWPlus;
996793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double dw_minus = params.rpDWMinus;
997793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double dw_min = params.rpDWMin;
998793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double dw_max = params.rpDWMax;
999793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1000793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int l_count = layer_count();
1001793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1002793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        // allocate buffers
1003793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<Mat> dw(l_count), dEdw(l_count), prev_dEdw_sign(l_count);
1004793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1005793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int total = 0;
1006793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( i = 0; i < l_count; i++ )
1007793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
1008793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            total += layer_sizes[i];
1009793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            dw[i].create(weights[i].size(), CV_64F);
1010793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            dw[i].setTo(Scalar::all(params.rpDW0));
1011793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            prev_dEdw_sign[i] = Mat::zeros(weights[i].size(), CV_8S);
1012793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            dEdw[i] = Mat::zeros(weights[i].size(), CV_64F);
1013793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
1014793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1015793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int dcount0 = max_buf_size/(2*total);
1016793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        dcount0 = std::max( dcount0, 1 );
1017793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        dcount0 = std::min( dcount0, count );
1018793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int chunk_count = (count + dcount0 - 1)/dcount0;
1019793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1020793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        // run rprop loop
1021793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        /*
1022793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         y_i(t) = w_i(t)*x_{i-1}(t)
1023793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         x_i(t) = f(y_i(t))
1024793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         E = sum_over_all_samples(1/2*||u - x_N||^2)
1025793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         grad_N = (x_N - u)*f'(y_i)
1026793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1027793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         std::min(dw_i{jk}(t)*dw_plus, dw_max), if dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) > 0
1028793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         dw_i{jk}(t) = std::max(dw_i{jk}(t)*dw_minus, dw_min), if dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) < 0
1029793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         dw_i{jk}(t-1) else
1030793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1031793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         if (dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) < 0)
1032793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         dE/dw_i{jk}(t)<-0
1033793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         else
1034793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         w_i{jk}(t+1) = w_i{jk}(t) + dw_i{jk}(t)
1035793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         grad_{i-1}(t) = w_i^t(t)*grad_i(t)
1036793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler         */
1037793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( iter = 0; iter < max_iter; iter++ )
1038793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
1039793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            double E = 0;
1040793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1041793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < l_count; i++ )
1042793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                dEdw[i].setTo(Scalar::all(0));
1043793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1044793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            // first, iterate through all the samples and compute dEdw
1045793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            RPropLoop invoker(this, inputs, outputs, _sw, dcount0, dEdw, &E);
1046793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            parallel_for_(Range(0, chunk_count), invoker);
1047793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            //invoker(Range(0, chunk_count));
1048793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1049793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            // now update weights
1050793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 1; i < l_count; i++ )
1051793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
1052793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                int n1 = layer_sizes[i-1], n2 = layer_sizes[i];
1053793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( int k = 0; k <= n1; k++ )
1054793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
1055793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    CV_Assert(weights[i].size() == Size(n2, n1+1));
1056793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double* wk = weights[i].ptr<double>(k);
1057793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double* dwk = dw[i].ptr<double>(k);
1058793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double* dEdwk = dEdw[i].ptr<double>(k);
1059793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    schar* prevEk = prev_dEdw_sign[i].ptr<schar>(k);
1060793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1061793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    for( int j = 0; j < n2; j++ )
1062793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    {
1063793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        double Eval = dEdwk[j];
1064793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        double dval = dwk[j];
1065793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        double wval = wk[j];
1066793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        int s = CV_SIGN(Eval);
1067793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        int ss = prevEk[j]*s;
1068793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        if( ss > 0 )
1069793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        {
1070793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            dval *= dw_plus;
1071793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            dval = std::min( dval, dw_max );
1072793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            dwk[j] = dval;
1073793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            wk[j] = wval + dval*s;
1074793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        }
1075793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        else if( ss < 0 )
1076793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        {
1077793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            dval *= dw_minus;
1078793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            dval = std::max( dval, dw_min );
1079793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            prevEk[j] = 0;
1080793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            dwk[j] = dval;
1081793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            wk[j] = wval + dval*s;
1082793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        }
1083793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        else
1084793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        {
1085793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            prevEk[j] = (schar)s;
1086793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            wk[j] = wval + dval*s;
1087793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        }
1088793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        dEdwk[j] = 0.;
1089793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    }
1090793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
1091793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
1092793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1093793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            //printf("%d. E = %g\n", iter, E);
1094793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( fabs(prev_E - E) < epsilon )
1095793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                break;
1096793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            prev_E = E;
1097793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
1098793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1099793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return iter;
1100793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1101793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1102793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void write_params( FileStorage& fs ) const
1103793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
1104793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const char* activ_func_name = activ_func == IDENTITY ? "IDENTITY" :
1105793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                      activ_func == SIGMOID_SYM ? "SIGMOID_SYM" :
1106793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                      activ_func == GAUSSIAN ? "GAUSSIAN" : 0;
1107793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1108793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( activ_func_name )
1109793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "activation_function" << activ_func_name;
1110793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else
1111793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "activation_function_id" << activ_func;
1112793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1113793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( activ_func != IDENTITY )
1114793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
1115793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "f_param1" << f_param1;
1116793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "f_param2" << f_param2;
1117793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
1118793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1119793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "min_val" << min_val << "max_val" << max_val << "min_val1" << min_val1 << "max_val1" << max_val1;
1120793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1121793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "training_params" << "{";
1122793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( params.trainMethod == ANN_MLP::BACKPROP )
1123793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
1124793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "train_method" << "BACKPROP";
1125793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "dw_scale" << params.bpDWScale;
1126793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "moment_scale" << params.bpMomentScale;
1127793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
1128793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else if( params.trainMethod == ANN_MLP::RPROP )
1129793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
1130793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "train_method" << "RPROP";
1131793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "dw0" << params.rpDW0;
1132793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "dw_plus" << params.rpDWPlus;
1133793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "dw_minus" << params.rpDWMinus;
1134793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "dw_min" << params.rpDWMin;
1135793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "dw_max" << params.rpDWMax;
1136793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
1137793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else
1138793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            CV_Error(CV_StsError, "Unknown training method");
1139793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1140793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "term_criteria" << "{";
1141793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( params.termCrit.type & TermCriteria::EPS )
1142793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "epsilon" << params.termCrit.epsilon;
1143793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( params.termCrit.type & TermCriteria::COUNT )
1144793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "iterations" << params.termCrit.maxCount;
1145793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "}" << "}";
1146793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1147793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1148793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void write( FileStorage& fs ) const
1149793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
1150793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( layer_sizes.empty() )
1151793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            return;
1152793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, l_count = layer_count();
1153793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1154793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "layer_sizes" << layer_sizes;
1155793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1156793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        write_params( fs );
1157793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1158793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        size_t esz = weights[0].elemSize();
1159793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1160793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "input_scale" << "[";
1161793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs.writeRaw("d", weights[0].ptr(), weights[0].total()*esz);
1162793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1163793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "]" << "output_scale" << "[";
1164793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs.writeRaw("d", weights[l_count].ptr(), weights[l_count].total()*esz);
1165793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1166793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "]" << "inv_output_scale" << "[";
1167793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs.writeRaw("d", weights[l_count+1].ptr(), weights[l_count+1].total()*esz);
1168793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1169793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "]" << "weights" << "[";
1170793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( i = 1; i < l_count; i++ )
1171793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
1172793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "[";
1173793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs.writeRaw("d", weights[i].ptr(), weights[i].total()*esz);
1174793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "]";
1175793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
1176793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "]";
1177793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1178793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1179793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void read_params( const FileNode& fn )
1180793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
1181793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        String activ_func_name = (String)fn["activation_function"];
1182793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( !activ_func_name.empty() )
1183793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
1184793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            activ_func = activ_func_name == "SIGMOID_SYM" ? SIGMOID_SYM :
1185793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                         activ_func_name == "IDENTITY" ? IDENTITY :
1186793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                         activ_func_name == "GAUSSIAN" ? GAUSSIAN : -1;
1187793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            CV_Assert( activ_func >= 0 );
1188793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
1189793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        else
1190793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            activ_func = (int)fn["activation_function_id"];
1191793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1192793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        f_param1 = (double)fn["f_param1"];
1193793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        f_param2 = (double)fn["f_param2"];
1194793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1195793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        setActivationFunction( activ_func, f_param1, f_param2 );
1196793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1197793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        min_val = (double)fn["min_val"];
1198793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        max_val = (double)fn["max_val"];
1199793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        min_val1 = (double)fn["min_val1"];
1200793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        max_val1 = (double)fn["max_val1"];
1201793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1202793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        FileNode tpn = fn["training_params"];
1203793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params = AnnParams();
1204793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1205793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( !tpn.empty() )
1206793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
1207793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            String tmethod_name = (String)tpn["train_method"];
1208793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1209793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( tmethod_name == "BACKPROP" )
1210793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
1211793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.trainMethod = ANN_MLP::BACKPROP;
1212793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.bpDWScale = (double)tpn["dw_scale"];
1213793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.bpMomentScale = (double)tpn["moment_scale"];
1214793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
1215793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            else if( tmethod_name == "RPROP" )
1216793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
1217793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.trainMethod = ANN_MLP::RPROP;
1218793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.rpDW0 = (double)tpn["dw0"];
1219793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.rpDWPlus = (double)tpn["dw_plus"];
1220793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.rpDWMinus = (double)tpn["dw_minus"];
1221793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.rpDWMin = (double)tpn["dw_min"];
1222793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.rpDWMax = (double)tpn["dw_max"];
1223793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
1224793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            else
1225793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                CV_Error(CV_StsParseError, "Unknown training method (should be BACKPROP or RPROP)");
1226793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1227793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            FileNode tcn = tpn["term_criteria"];
1228793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( !tcn.empty() )
1229793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
1230793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                FileNode tcn_e = tcn["epsilon"];
1231793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                FileNode tcn_i = tcn["iterations"];
1232793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                params.termCrit.type = 0;
1233793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( !tcn_e.empty() )
1234793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
1235793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    params.termCrit.type |= TermCriteria::EPS;
1236793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    params.termCrit.epsilon = (double)tcn_e;
1237793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
1238793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( !tcn_i.empty() )
1239793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
1240793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    params.termCrit.type |= TermCriteria::COUNT;
1241793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    params.termCrit.maxCount = (int)tcn_i;
1242793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
1243793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
1244793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
1245793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1246793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1247793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void read( const FileNode& fn )
1248793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
1249793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        clear();
1250793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1251793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<int> _layer_sizes;
1252793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        readVectorOrMat(fn["layer_sizes"], _layer_sizes);
1253793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        setLayerSizes( _layer_sizes );
1254793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1255793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, l_count = layer_count();
1256793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        read_params(fn);
1257793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1258793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        size_t esz = weights[0].elemSize();
1259793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1260793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        FileNode w = fn["input_scale"];
1261793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        w.readRaw("d", weights[0].ptr(), weights[0].total()*esz);
1262793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1263793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        w = fn["output_scale"];
1264793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        w.readRaw("d", weights[l_count].ptr(), weights[l_count].total()*esz);
1265793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1266793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        w = fn["inv_output_scale"];
1267793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        w.readRaw("d", weights[l_count+1].ptr(), weights[l_count+1].total()*esz);
1268793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1269793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        FileNodeIterator w_it = fn["weights"].begin();
1270793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1271793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( i = 1; i < l_count; i++, ++w_it )
1272793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            (*w_it).readRaw("d", weights[i].ptr(), weights[i].total()*esz);
1273793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        trained = true;
1274793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1275793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1276793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    Mat getWeights(int layerIdx) const
1277793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
1278793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        CV_Assert( 0 <= layerIdx && layerIdx < (int)weights.size() );
1279793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return weights[layerIdx];
1280793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1281793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1282793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    bool isTrained() const
1283793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
1284793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return trained;
1285793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1286793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1287793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    bool isClassifier() const
1288793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
1289793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return false;
1290793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1291793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1292793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    int getVarCount() const
1293793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
1294793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return layer_sizes.empty() ? 0 : layer_sizes[0];
1295793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1296793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1297793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    String getDefaultName() const
1298793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
1299793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return "opencv_ml_ann_mlp";
1300793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
1301793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
    // Neuron count per layer; layer_sizes[0] is reported by getVarCount().
    vector<int> layer_sizes;
    // Layout as serialized by write()/read(), with l_count = layer_count():
    //   weights[0]                 - input scale vector
    //   weights[1 .. l_count-1]    - per-layer weight matrices
    //   weights[l_count]           - output scale vector
    //   weights[l_count+1]         - inverse output scale vector
    vector<Mat> weights;
    // Activation function parameters, passed to setActivationFunction().
    double f_param1, f_param2;
    // Value ranges saved and restored with the model.
    // NOTE(review): their exact role is defined outside this chunk — confirm.
    double min_val, max_val, min_val1, max_val1;
    // Activation function id (IDENTITY, SIGMOID_SYM, GAUSSIAN, or another id
    // that write_params() stores numerically).
    int activ_func;
    // NOTE(review): size limits used outside this chunk — confirm semantics.
    int max_lsize, max_buf_sz;
    // Training method, its parameters and the termination criteria.
    AnnParams params;
    // NOTE(review): not used in this chunk; presumably used during training.
    RNG rng;
    // NOTE(review): not used in this chunk; presumably guards shared state — confirm.
    Mutex mtx;
    // Set to true by read() once a network has been loaded; see isTrained().
    bool trained;
1312793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler};
1313793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1314793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1315793ee12c6df9cad3806238d32528c49a3ff9331dNoah PreslerPtr<ANN_MLP> ANN_MLP::create()
1316793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
1317793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    return makePtr<ANN_MLPImpl>();
1318793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}
1319793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1320793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}}
1321793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
1322793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler/* End of file. */
1323