16acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*M///////////////////////////////////////////////////////////////////////////////////////
26acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
36acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
46acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
56acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  By downloading, copying, installing or using the software you agree to this license.
66acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  If you do not agree to this license, do not download, install,
76acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  copy or use the software.
86acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
96acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//                        Intel License Agreement
116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Copyright (C) 2000, Intel Corporation, all rights reserved.
136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Third party copyrights are property of their respective owners.
146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Redistribution and use in source and binary forms, with or without modification,
166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are permitted provided that the following conditions are met:
176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's of source code must retain the above copyright notice,
196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer.
206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's in binary form must reproduce the above copyright notice,
226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer in the documentation
236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     and/or other materials provided with the distribution.
246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * The name of Intel Corporation may not be used to endorse or promote products
266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     derived from this software without specific prior written permission.
276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// This software is provided by the copyright holders and contributors "as is" and
296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// any express or implied warranties, including, but not limited to, the implied
306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// warranties of merchantability and fitness for a particular purpose are disclaimed.
316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct,
326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages
336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// (including, but not limited to, procurement of substitute goods or services;
346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused
356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability,
366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort (including negligence or otherwise) arising in any way out of
376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even if advised of the possibility of such damage.
386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/
406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h"
426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP_TrainParams::CvANN_MLP_TrainParams()
446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    term_crit = cvTermCriteria( CV_TERMCRIT_ITER + CV_TERMCRIT_EPS, 1000, 0.01 );
466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    train_method = RPROP;
476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bp_dw_scale = bp_moment_scale = 0.1;
486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rp_dw0 = 0.1; rp_dw_plus = 1.2; rp_dw_minus = 0.5;
496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rp_dw_min = FLT_EPSILON; rp_dw_max = 50.;
506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP_TrainParams::CvANN_MLP_TrainParams( CvTermCriteria _term_crit,
546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                              int _train_method,
556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                              double _param1, double _param2 )
566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    term_crit = _term_crit;
586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    train_method = _train_method;
596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bp_dw_scale = bp_moment_scale = 0.1;
606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rp_dw0 = 1.; rp_dw_plus = 1.2; rp_dw_minus = 0.5;
616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rp_dw_min = FLT_EPSILON; rp_dw_max = 50.;
626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( train_method == RPROP )
646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rp_dw0 = _param1;
666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( rp_dw0 < FLT_EPSILON )
676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rp_dw0 = 1.;
686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rp_dw_min = _param2;
696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rp_dw_min = MAX( rp_dw_min, 0 );
706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( train_method == BACKPROP )
726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        bp_dw_scale = _param1;
746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( bp_dw_scale <= 0 )
756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            bp_dw_scale = 0.1;
766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        bp_dw_scale = MAX( bp_dw_scale, 1e-3 );
776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        bp_dw_scale = MIN( bp_dw_scale, 1 );
786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        bp_moment_scale = _param2;
796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( bp_moment_scale < 0 )
806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            bp_moment_scale = 0.1;
816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        bp_moment_scale = MIN( bp_moment_scale, 1 );
826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        train_method = RPROP;
856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP_TrainParams::~CvANN_MLP_TrainParams()
896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP::CvANN_MLP()
946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer_sizes = wbuf = 0;
966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    min_val = max_val = min_val1 = max_val1 = 0.;
976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weights = 0;
986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rng = cvRNG(-1);
996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default_model_name = "my_nn";
1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP::CvANN_MLP( const CvMat* _layer_sizes,
1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      int _activ_func,
1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      double _f_param1, double _f_param2 )
1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer_sizes = wbuf = 0;
1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    min_val = max_val = min_val1 = max_val1 = 0.;
1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weights = 0;
1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rng = cvRNG(-1);
1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default_model_name = "my_nn";
1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    create( _layer_sizes, _activ_func, _f_param1, _f_param2 );
1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP::~CvANN_MLP()
1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::clear()
1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &layer_sizes );
1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &wbuf );
1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &weights );
1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    activ_func = SIGMOID_SYM;
1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    f_param1 = f_param2 = 1;
1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_buf_sz = 1 << 12;
1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::set_activ_func( int _activ_func, double _f_param1, double _f_param2 )
1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::set_activ_func" );
1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _activ_func < 0 || _activ_func > GAUSSIAN )
1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsOutOfRange, "Unknown activation function" );
1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    activ_func = _activ_func;
1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    switch( activ_func )
1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case SIGMOID_SYM:
1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        max_val = 0.95; min_val = -max_val;
1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        max_val1 = 0.98; min_val1 = -max_val1;
1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( fabs(_f_param1) < FLT_EPSILON )
1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _f_param1 = 2./3;
1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( fabs(_f_param2) < FLT_EPSILON )
1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _f_param2 = 1.7159;
1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case GAUSSIAN:
1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        max_val = 1.; min_val = 0.05;
1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        max_val1 = 1.; min_val1 = 0.02;
1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( fabs(_f_param1) < FLT_EPSILON )
1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _f_param1 = 1.;
1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( fabs(_f_param2) < FLT_EPSILON )
1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _f_param2 = 1.;
1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default:
1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        min_val = max_val = min_val1 = max_val1 = 0.;
1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _f_param1 = 1.;
1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _f_param2 = 0.;
1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    f_param1 = _f_param1;
1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    f_param2 = _f_param2;
1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::init_weights()
1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, k;
1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 1; i < layer_sizes->cols; i++ )
1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n1 = layer_sizes->data.i[i-1];
1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n2 = layer_sizes->data.i[i];
1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double val = 0, G = n2 > 2 ? 0.7*pow((double)n1,1./(n2-1)) : 1.;
1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double* w = weights[i];
1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // initialize weights using Nguyen-Widrow algorithm
1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < n2; j++ )
1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double s = 0;
1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k <= n1; k++ )
1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                val = cvRandReal(&rng)*2-1.;
1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                w[k*n2 + j] = val;
1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                s += val;
1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( i < layer_sizes->cols - 1 )
1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                s = 1./(s - val);
2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( k = 0; k <= n1; k++ )
2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    w[k*n2 + j] *= s;
2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                w[n1*n2 + j] *= G*(-1+j*2./n2);
2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::create( const CvMat* _layer_sizes, int _activ_func,
2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        double _f_param1, double _f_param2 )
2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::create" );
2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, l_step, l_count, buf_sz = 0;
2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int *l_src, *l_dst;
2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_MAT(_layer_sizes) ||
2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _layer_sizes->cols != 1 && _layer_sizes->rows != 1 ||
2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_MAT_TYPE(_layer_sizes->type) != CV_32SC1 )
2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg,
2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "The array of layer neuron counters must be an integer vector" );
2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( set_activ_func( _activ_func, _f_param1, _f_param2 ));
2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    l_count = _layer_sizes->rows + _layer_sizes->cols - 1;
2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    l_src = _layer_sizes->data.i;
2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    l_step = CV_IS_MAT_CONT(_layer_sizes->type) ? 1 :
2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                _layer_sizes->step / sizeof(l_src[0]);
2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( layer_sizes = cvCreateMat( 1, l_count, CV_32SC1 ));
2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    l_dst = layer_sizes->data.i;
2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_count = 0;
2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < l_count; i++ )
2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n = l_src[i*l_step];
2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( n < 1 + (0 < i && i < l_count-1))
2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsOutOfRange,
2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "there should be at least one input and one output "
2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "and every hidden layer must have more than 1 neuron" );
2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        l_dst[i] = n;
2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        max_count = MAX( max_count, n );
2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( i > 0 )
2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            buf_sz += (l_dst[i-1]+1)*n;
2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_sz += (l_dst[0] + l_dst[l_count-1]*2)*2;
2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( wbuf = cvCreateMat( 1, buf_sz, CV_64F ));
2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( weights = (double**)cvAlloc( (l_count+1)*sizeof(weights[0]) ));
2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weights[0] = wbuf->data.db;
2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weights[1] = weights[0] + l_dst[0]*2;
2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 1; i < l_count; i++ )
2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        weights[i+1] = weights[i] + (l_dst[i-1] + 1)*l_dst[i];
2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weights[l_count+1] = weights[l_count] + l_dst[l_count-1]*2;
2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat CvANN_MLP::predict( const CvMat* _inputs, CvMat* _outputs ) const
2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::predict" );
2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* buf;
2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, n, dn = 0, l_count, dn0, buf_sz, min_buf_sz;
2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !layer_sizes )
2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsError, "The network has not been initialized" );
2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_MAT(_inputs) || !CV_IS_MAT(_outputs) ||
2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        !CV_ARE_TYPES_EQ(_inputs,_outputs) ||
2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_MAT_TYPE(_inputs->type) != CV_32FC1 &&
2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_MAT_TYPE(_inputs->type) != CV_64FC1 ||
2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _inputs->rows != _outputs->rows )
2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Both input and output must be floating-point matrices "
2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                "of the same type and have the same number of rows" );
2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _inputs->cols != layer_sizes->data.i[0] )
2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadSize, "input matrix must have the same number of columns as "
2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                 "the number of neurons in the input layer" );
2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _outputs->cols != layer_sizes->data.i[layer_sizes->cols - 1] )
2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadSize, "output matrix must have the same number of columns as "
2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                 "the number of neurons in the output layer" );
2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    n = dn0 = _inputs->rows;
2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    min_buf_sz = 2*max_count;
2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_sz = n*min_buf_sz;
2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( buf_sz > max_buf_sz )
2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dn0 = max_buf_sz/min_buf_sz;
3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dn0 = MAX( dn0, 1 );
3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        buf_sz = dn0*min_buf_sz;
3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf = (double*)cvStackAlloc( buf_sz*sizeof(buf[0]) );
3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    l_count = layer_sizes->cols;
3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n; i += dn )
3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat hdr[2], _w, *layer_in = &hdr[0], *layer_out = &hdr[1], *temp;
3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dn = MIN( dn0, n - i );
3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGetRows( _inputs, layer_in, i, i + dn );
3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvInitMatHeader( layer_out, dn, layer_in->cols, CV_64F, buf );
3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        scale_input( layer_in, layer_out );
3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_SWAP( layer_in, layer_out, temp );
3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 1; j < l_count; j++ )
3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double* data = buf + (j&1 ? max_count*dn0 : 0);
3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int cols = layer_sizes->data.i[j];
3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( layer_out, dn, cols, CV_64F, data );
3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( &_w, layer_in->cols, layer_out->cols, CV_64F, weights[j] );
3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvGEMM( layer_in, &_w, 1, 0, 0, layer_out );
3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            calc_activ_func( layer_out, _w.data.db + _w.rows*_w.cols );
3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_SWAP( layer_in, layer_out, temp );
3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGetRows( _outputs, layer_out, i, i + dn );
3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        scale_output( layer_in, layer_out );
3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return 0.f;
3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::scale_input( const CvMat* _src, CvMat* _dst ) const
3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, cols = _src->cols;
3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* dst = _dst->data.db;
3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* w = weights[0];
3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int step = _src->step;
3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( CV_MAT_TYPE( _src->type ) == CV_32F )
3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const float* src = _src->data.fl;
3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        step /= sizeof(src[0]);
3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < _src->rows; i++, src += step, dst += cols )
3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cols; j++ )
3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[j] = src[j]*w[j*2] + w[j*2+1];
3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const double* src = _src->data.db;
3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        step /= sizeof(src[0]);
3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < _src->rows; i++, src += step, dst += cols )
3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cols; j++ )
3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[j] = src[j]*w[j*2] + w[j*2+1];
3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::scale_output( const CvMat* _src, CvMat* _dst ) const
3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, cols = _src->cols;
3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* src = _src->data.db;
3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* w = weights[layer_sizes->cols];
3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int step = _dst->step;
3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( CV_MAT_TYPE( _dst->type ) == CV_32F )
3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float* dst = _dst->data.fl;
3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        step /= sizeof(dst[0]);
3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < _src->rows; i++, src += cols, dst += step )
3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cols; j++ )
3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[j] = (float)(src[j]*w[j*2] + w[j*2+1]);
3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double* dst = _dst->data.db;
3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        step /= sizeof(dst[0]);
3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < _src->rows; i++, src += cols, dst += step )
3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cols; j++ )
3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[j] = src[j]*w[j*2] + w[j*2+1];
3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::calc_activ_func( CvMat* sums, const double* bias ) const
3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, n = sums->rows, cols = sums->cols;
4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* data = sums->data.db;
4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double scale = 0, scale2 = f_param2;
4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    switch( activ_func )
4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case IDENTITY:
4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        scale = 1.;
4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case SIGMOID_SYM:
4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        scale = -f_param1;
4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case GAUSSIAN:
4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        scale = -f_param1*f_param1;
4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default:
4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ;
4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert( CV_IS_MAT_CONT(sums->type) );
4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( activ_func != GAUSSIAN )
4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++, data += cols )
4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cols; j++ )
4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                data[j] = (data[j] + bias[j])*scale;
4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( activ_func == IDENTITY )
4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            return;
4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++, data += cols )
4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cols; j++ )
4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double t = data[j] + bias[j];
4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                data[j] = t*t*scale;
4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvExp( sums, sums );
4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    n *= cols;
4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data -= n;
4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    switch( activ_func )
4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case SIGMOID_SYM:
4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i <= n - 4; i += 4 )
4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double x0 = 1.+data[i], x1 = 1.+data[i+1], x2 = 1.+data[i+2], x3 = 1.+data[i+3];
4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double a = x0*x1, b = x2*x3, d = scale2/(a*b), t0, t1;
4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            a *= d; b *= d;
4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            t0 = (2 - x0)*b*x1; t1 = (2 - x1)*b*x0;
4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            data[i] = t0; data[i+1] = t1;
4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            t0 = (2 - x2)*a*x3; t1 = (2 - x3)*a*x2;
4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            data[i+2] = t0; data[i+3] = t1;
4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ; i < n; i++ )
4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double t = scale2*(1. - data[i])/(1. + data[i]);
4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            data[i] = t;
4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case GAUSSIAN:
4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            data[i] = scale2*data[i];
4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default:
4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ;
4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::calc_activ_func_deriv( CvMat* _xf, CvMat* _df,
4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                       const double* bias ) const
4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, n = _xf->rows, cols = _xf->cols;
4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* xf = _xf->data.db;
4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* df = _df->data.db;
4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double scale, scale2 = f_param2;
4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert( CV_IS_MAT_CONT( _xf->type & _df->type ) );
4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( activ_func == IDENTITY )
4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++, xf += cols, df += cols )
4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cols; j++ )
4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                xf[j] += bias[j];
4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                df[j] = 1;
4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return;
4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( activ_func == GAUSSIAN )
4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        scale = -f_param1*f_param1;
4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        scale2 *= scale;
4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++, xf += cols, df += cols )
5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cols; j++ )
5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double t = xf[j] + bias[j];
5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                df[j] = t*2*scale2;
5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                xf[j] = t*t*scale;
5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        scale = -f_param1;
5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++, xf += cols, df += cols )
5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cols; j++ )
5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                xf[j] = (xf[j] + bias[j])*scale;
5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvExp( _xf, _xf );
5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    n *= cols;
5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    xf -= n; df -= n;
5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // ((1+exp(-ax))^-1)'=a*((1+exp(-ax))^-2)*exp(-ax);
5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // ((1-exp(-ax))/(1+exp(-ax)))'=(a*exp(-ax)*(1+exp(-ax)) + a*exp(-ax)*(1-exp(-ax)))/(1+exp(-ax))^2=
5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // 2*a*exp(-ax)/(1+exp(-ax))^2
5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    switch( activ_func )
5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case SIGMOID_SYM:
5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        scale *= -2*f_param2;
5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i <= n - 4; i += 4 )
5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double x0 = 1.+xf[i], x1 = 1.+xf[i+1], x2 = 1.+xf[i+2], x3 = 1.+xf[i+3];
5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double a = x0*x1, b = x2*x3, d = 1./(a*b), t0, t1;
5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            a *= d; b *= d;
5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            t0 = b*x1; t1 = b*x0;
5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            df[i] = scale*xf[i]*t0*t0;
5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            df[i+1] = scale*xf[i+1]*t1*t1;
5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            t0 *= scale2*(2 - x0); t1 *= scale2*(2 - x1);
5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            xf[i] = t0; xf[i+1] = t1;
5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            t0 = a*x3; t1 = a*x2;
5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            df[i+2] = scale*xf[i+2]*t0*t0;
5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            df[i+3] = scale*xf[i+3]*t1*t1;
5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            t0 *= scale2*(2 - x2); t1 *= scale2*(2 - x3);
5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            xf[i+2] = t0; xf[i+3] = t1;
5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ; i < n; i++ )
5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double t0 = 1./(1. + xf[i]);
5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double t1 = scale*xf[i]*t0*t0;
5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            t0 *= scale2*(1. - xf[i]);
5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            df[i] = t1;
5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            xf[i] = t0;
5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case GAUSSIAN:
5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            df[i] *= xf[i];
5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default:
5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ;
5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::calc_input_scale( const CvVectors* vecs, int flags )
5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool reset_weights = (flags & UPDATE_WEIGHTS) == 0;
5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool no_scale = (flags & NO_INPUT_SCALE) != 0;
5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* scale = weights[0];
5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int count = vecs->count;
5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( reset_weights )
5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i, j, vcount = layer_sizes->data.i[0];
5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int type = vecs->type;
5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double a = no_scale ? 1. : 0.;
5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < vcount; j++ )
5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            scale[2*j] = a, scale[j*2+1] = 0.;
5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( no_scale )
5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            return;
5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; i++ )
5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const float* f = vecs->data.fl[i];
5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const double* d = vecs->data.db[i];
5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < vcount; j++ )
5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double t = type == CV_32F ? (double)f[j] : d[j];
5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                scale[j*2] += t;
5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                scale[j*2+1] += t*t;
5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < vcount; j++ )
5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double s = scale[j*2], s2 = scale[j*2+1];
6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double m = s/count, sigma2 = s2/count - m*m;
6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            scale[j*2] = sigma2 < DBL_EPSILON ? 1 : 1./sqrt(sigma2);
6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            scale[j*2+1] = -m*scale[j*2];
6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::calc_output_scale( const CvVectors* vecs, int flags )
6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, vcount = layer_sizes->data.i[layer_sizes->cols-1];
6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int type = vecs->type;
6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double m = min_val, M = max_val, m1 = min_val1, M1 = max_val1;
6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool reset_weights = (flags & UPDATE_WEIGHTS) == 0;
6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool no_scale = (flags & NO_OUTPUT_SCALE) != 0;
6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int l_count = layer_sizes->cols;
6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* scale = weights[l_count];
6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* inv_scale = weights[l_count+1];
6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int count = vecs->count;
6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::calc_output_scale" );
6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( reset_weights )
6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double a0 = no_scale ? 1 : DBL_MAX, b0 = no_scale ? 0 : -DBL_MAX;
6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < vcount; j++ )
6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            scale[2*j] = inv_scale[2*j] = a0;
6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            scale[j*2+1] = inv_scale[2*j+1] = b0;
6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( no_scale )
6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            EXIT;
6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < count; i++ )
6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const float* f = vecs->data.fl[i];
6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const double* d = vecs->data.db[i];
6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < vcount; j++ )
6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double t = type == CV_32F ? (double)f[j] : d[j];
6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( reset_weights )
6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double mj = scale[j*2], Mj = scale[j*2+1];
6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( mj > t ) mj = t;
6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( Mj < t ) Mj = t;
6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                scale[j*2] = mj;
6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                scale[j*2+1] = Mj;
6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                t = t*scale[j*2] + scale[2*j+1];
6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( t < m1 || t > M1 )
6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    CV_ERROR( CV_StsOutOfRange,
6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    "Some of new output training vector components run exceed the original range too much" );
6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( reset_weights )
6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < vcount; j++ )
6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // map mj..Mj to m..M
6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double mj = scale[j*2], Mj = scale[j*2+1];
6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double a, b;
6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double delta = Mj - mj;
6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( delta < DBL_EPSILON )
6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                a = 1, b = (M + m - Mj - mj)*0.5;
6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                a = (M - m)/delta, b = m - mj*a;
6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            inv_scale[j*2] = a; inv_scale[j*2+1] = b;
6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            a = 1./a; b = -b*a;
6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            scale[j*2] = a; scale[j*2+1] = b;
6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvANN_MLP::prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const CvMat* _sample_weights, const CvMat* _sample_idx,
6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags )
6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool ok = false;
6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* sample_idx = 0;
6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvVectors ivecs, ovecs;
6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* sw = 0;
6946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int count = 0;
6956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::prepare_to_train" );
6976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ivecs.data.ptr = ovecs.data.ptr = 0;
6996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert( _ivecs && _ovecs );
7006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
7026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* sidx = 0;
7046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, sw_type = 0, sw_count = 0;
7056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int sw_step = 0;
7066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double sw_sum = 0;
7076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !layer_sizes )
7096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsError,
7106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "The network has not been created. Use method create or the appropriate constructor" );
7116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_MAT(_inputs) || CV_MAT_TYPE(_inputs->type) != CV_32FC1 &&
7136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_MAT_TYPE(_inputs->type) != CV_64FC1 || _inputs->cols != layer_sizes->data.i[0] )
7146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg,
7156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "input training data should be a floating-point matrix with"
7166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "the number of rows equal to the number of training samples and "
7176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "the number of columns equal to the size of 0-th (input) layer" );
7186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_MAT(_outputs) || CV_MAT_TYPE(_outputs->type) != CV_32FC1 &&
7206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_MAT_TYPE(_outputs->type) != CV_64FC1 ||
7216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _outputs->cols != layer_sizes->data.i[layer_sizes->cols - 1] )
7226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg,
7236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "output training data should be a floating-point matrix with"
7246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "the number of rows equal to the number of training samples and "
7256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "the number of columns equal to the size of last (output) layer" );
7266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _inputs->rows != _outputs->rows )
7286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsUnmatchedSizes, "The numbers of input and output samples do not match" );
7296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _sample_idx )
7316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, _inputs->rows ));
7336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sidx = sample_idx->data.i;
7346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        count = sample_idx->cols + sample_idx->rows - 1;
7356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
7376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        count = _inputs->rows;
7386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _sample_weights )
7406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(_sample_weights) )
7426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "sample_weights (if passed) must be a valid matrix" );
7436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sw_type = CV_MAT_TYPE(_sample_weights->type);
7456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sw_count = _sample_weights->cols + _sample_weights->rows - 1;
7466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( sw_type != CV_32FC1 && sw_type != CV_64FC1 ||
7486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _sample_weights->cols != 1 && _sample_weights->rows != 1 ||
7496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sw_count != count && sw_count != _inputs->rows )
7506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg,
7516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "sample_weights must be 1d floating-point vector containing weights "
7526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "of all or selected training samples" );
7536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sw_step = CV_IS_MAT_CONT(_sample_weights->type) ? 1 :
7556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _sample_weights->step/CV_ELEM_SIZE(sw_type);
7566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( sw = (double*)cvAlloc( count*sizeof(sw[0]) ));
7586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( ivecs.data.ptr = (uchar**)cvAlloc( count*sizeof(ivecs.data.ptr[0]) ));
7616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( ovecs.data.ptr = (uchar**)cvAlloc( count*sizeof(ovecs.data.ptr[0]) ));
7626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ivecs.type = CV_MAT_TYPE(_inputs->type);
7646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ovecs.type = CV_MAT_TYPE(_outputs->type);
7656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ivecs.count = ovecs.count = count;
7666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < count; i++ )
7686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = sidx ? sidx[i] : i;
7706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ivecs.data.ptr[i] = _inputs->data.ptr + idx*_inputs->step;
7716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ovecs.data.ptr[i] = _outputs->data.ptr + idx*_outputs->step;
7726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( sw )
7736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int si = sw_count == count ? i : idx;
7756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = sw_type == CV_32FC1 ?
7766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                (double)_sample_weights->data.fl[si*sw_step] :
7776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                _sample_weights->data.db[si*sw_step];
7786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sw[i] = w;
7796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( w < 0 )
7806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_ERROR( CV_StsOutOfRange, "some of sample weights are negative" );
7816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sw_sum += w;
7826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // normalize weights
7866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( sw )
7876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sw_sum = sw_sum > DBL_EPSILON ? 1./sw_sum : 0;
7896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; i++ )
7906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sw[i] *= sw_sum;
7916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    calc_input_scale( &ivecs, _flags );
7946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( calc_output_scale( &ovecs, _flags ));
7956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ok = true;
7976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
7996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ok )
8016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &ivecs.data.ptr );
8036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &ovecs.data.ptr );
8046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &sw );
8056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &sample_idx );
8086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    *_ivecs = ivecs;
8096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    *_ovecs = ovecs;
8106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    *_sw = sw;
8116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return ok;
8136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvANN_MLP::train( const CvMat* _inputs, const CvMat* _outputs,
8176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      const CvMat* _sample_weights, const CvMat* _sample_idx,
8186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      CvANN_MLP_TrainParams _params, int flags )
8196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int MAX_ITER = 1000;
8216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double DEFAULT_EPSILON = FLT_EPSILON;
8226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* sw = 0;
8246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvVectors x0, u;
8256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int iter = -1;
8266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    x0.data.ptr = u.data.ptr = 0;
8286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::train" );
8306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
8326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int max_iter;
8346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double epsilon;
8356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params = _params;
8376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // initialize training data
8396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( prepare_to_train( _inputs, _outputs, _sample_weights,
8406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               _sample_idx, &x0, &u, &sw, flags ));
8416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // ... and link weights
8436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !(flags & UPDATE_WEIGHTS) )
8446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        init_weights();
8456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_iter = params.term_crit.type & CV_TERMCRIT_ITER ? params.term_crit.max_iter : MAX_ITER;
8476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_iter = MIN( max_iter, MAX_ITER );
8486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_iter = MAX( max_iter, 1 );
8496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    epsilon = params.term_crit.type & CV_TERMCRIT_EPS ? params.term_crit.epsilon : DEFAULT_EPSILON;
8516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    epsilon = MAX(epsilon, DBL_EPSILON);
8526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.term_crit.type = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS;
8546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.term_crit.max_iter = max_iter;
8556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.term_crit.epsilon = epsilon;
8566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.train_method == CvANN_MLP_TrainParams::BACKPROP )
8586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( iter = train_backprop( x0, u, sw ));
8606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
8626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( iter = train_rprop( x0, u, sw ));
8646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
8676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &x0.data.ptr );
8696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &u.data.ptr );
8706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &sw );
8716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return iter;
8736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvANN_MLP::train_backprop( CvVectors x0, CvVectors u, const double* sw )
8776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dw = 0;
8796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* buf = 0;
8806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double **x = 0, **df = 0;
8816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* _idx = 0;
8826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int iter = -1, count = x0.count;
8836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::train_backprop" );
8856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
8876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, k, ivcount, ovcount, l_count, total = 0, max_iter;
8896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double *buf_ptr;
8906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double prev_E = DBL_MAX*0.5, E = 0, epsilon;
8916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_iter = params.term_crit.max_iter*count;
8936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    epsilon = params.term_crit.epsilon*count;
8946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    l_count = layer_sizes->cols;
8966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ivcount = layer_sizes->data.i[0];
8976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ovcount = layer_sizes->data.i[l_count-1];
8986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // allocate buffers
9006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < l_count; i++ )
9016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        total += layer_sizes->data.i[i] + 1;
9026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( dw = cvCreateMat( wbuf->rows, wbuf->cols, wbuf->type ));
9046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( dw );
9056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( buf = cvCreateMat( 1, (total + max_count)*2, CV_64F ));
9066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( _idx = cvCreateMat( 1, count, CV_32SC1 ));
9076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < count; i++ )
9086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _idx->data.i[i] = i;
9096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( x = (double**)cvAlloc( total*2*sizeof(x[0]) ));
9116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    df = x + total;
9126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_ptr = buf->data.db;
9136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( j = 0; j < l_count; j++ )
9156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        x[j] = buf_ptr;
9176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        df[j] = x[j] + layer_sizes->data.i[j];
9186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        buf_ptr += (df[j] - x[j])*2;
9196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
9206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // run back-propagation loop
9226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    /*
9236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        y_i = w_i*x_{i-1}
9246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        x_i = f(y_i)
9256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        E = 1/2*||u - x_N||^2
9266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        grad_N = (x_N - u)*f'(y_i)
9276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dw_i(t) = momentum*dw_i(t-1) + dw_scale*x_{i-1}*grad_i
9286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        w_i(t+1) = w_i(t) + dw_i(t)
9296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        grad_{i-1} = w_i^t*grad_i
9306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    */
9316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( iter = 0; iter < max_iter; iter++ )
9326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = iter % count;
9346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double* w = weights[0];
9356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sweight = sw ? count*sw[idx] : 1.;
9366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat _w, _dw, hdr1, hdr2, ghdr1, ghdr2, _df;
9376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat *x1 = &hdr1, *x2 = &hdr2, *grad1 = &ghdr1, *grad2 = &ghdr2, *temp;
9386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( idx == 0 )
9406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( fabs(prev_E - E) < epsilon )
9426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                break;
9436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            prev_E = E;
9446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            E = 0;
9456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // shuffle indices
9476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
9486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
9496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int tt;
9506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                j = (unsigned)cvRandInt(&rng) % count;
9516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                k = (unsigned)cvRandInt(&rng) % count;
9526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_SWAP( _idx->data.i[j], _idx->data.i[k], tt );
9536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
9546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        idx = _idx->data.i[idx];
9576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( x0.type == CV_32F )
9596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const float* x0data = x0.data.fl[idx];
9616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < ivcount; j++ )
9626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                x[0][j] = x0data[j]*w[j*2] + w[j*2 + 1];
9636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
9656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const double* x0data = x0.data.db[idx];
9676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < ivcount; j++ )
9686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                x[0][j] = x0data[j]*w[j*2] + w[j*2 + 1];
9696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvInitMatHeader( x1, 1, ivcount, CV_64F, x[0] );
9726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // forward pass, compute y[i]=w*x[i-1], x[i]=f(y[i]), df[i]=f'(y[i])
9746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 1; i < l_count; i++ )
9756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( x2, 1, layer_sizes->data.i[i], CV_64F, x[i] );
9776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( &_w, x1->cols, x2->cols, CV_64F, weights[i] );
9786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvGEMM( x1, &_w, 1, 0, 0, x2 );
9796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _df = *x2;
9806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _df.data.db = df[i];
9816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            calc_activ_func_deriv( x2, &_df, _w.data.db + _w.rows*_w.cols );
9826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_SWAP( x1, x2, temp );
9836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvInitMatHeader( grad1, 1, ovcount, CV_64F, buf_ptr );
9866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        *grad2 = *grad1;
9876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        grad2->data.db = buf_ptr + max_count;
9886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        w = weights[l_count+1];
9906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // calculate error
9926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( u.type == CV_32F )
9936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const float* udata = u.data.fl[idx];
9956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k < ovcount; k++ )
9966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
9976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double t = udata[k]*w[k*2] + w[k*2+1] - x[l_count-1][k];
9986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                grad1->data.db[k] = t*sweight;
9996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                E += t*t;
10006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
10016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
10036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const double* udata = u.data.db[idx];
10056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k < ovcount; k++ )
10066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
10076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double t = udata[k]*w[k*2] + w[k*2+1] - x[l_count-1][k];
10086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                grad1->data.db[k] = t*sweight;
10096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                E += t*t;
10106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
10116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        E *= sweight;
10136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // backward pass, update weights
10156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = l_count-1; i > 0; i-- )
10166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int n1 = layer_sizes->data.i[i-1], n2 = layer_sizes->data.i[i];
10186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( &_df, 1, n2, CV_64F, df[i] );
10196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvMul( grad1, &_df, grad1 );
10206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( &_w, n1+1, n2, CV_64F, weights[i] );
10216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( &_dw, n1+1, n2, CV_64F, dw->data.db + (weights[i] - weights[0]) );
10226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( x1, n1+1, 1, CV_64F, x[i-1] );
10236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            x[i-1][n1] = 1.;
10246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvGEMM( x1, grad1, params.bp_dw_scale, &_dw, params.bp_moment_scale, &_dw );
10256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvAdd( &_w, &_dw, &_w );
10266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( i > 1 )
10276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
10286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                grad2->cols = n1;
10296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                _w.rows = n1;
10306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvGEMM( grad1, &_w, 1, 0, 0, grad2, CV_GEMM_B_T );
10316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
10326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_SWAP( grad1, grad2, temp );
10336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
10356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    iter /= count;
10376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
10396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dw );
10416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &buf );
10426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &_idx );
10436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &x );
10446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return iter;
10466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
10476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvANN_MLP::train_rprop( CvVectors x0, CvVectors u, const double* sw )
10506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
10516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int max_buf_sz = 1 << 16;
10526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dw = 0;
10536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dEdw = 0;
10546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* prev_dEdw_sign = 0;
10556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* buf = 0;
10566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double **x = 0, **df = 0;
10576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int iter = -1, count = x0.count;
10586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::train" );
10606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
10626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, ivcount, ovcount, l_count, total = 0, max_iter, buf_sz, dcount0, dcount=0;
10646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double *buf_ptr;
10656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double prev_E = DBL_MAX*0.5, epsilon;
10666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double dw_plus, dw_minus, dw_min, dw_max;
10676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double inv_count;
10686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_iter = params.term_crit.max_iter;
10706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    epsilon = params.term_crit.epsilon;
10716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dw_plus = params.rp_dw_plus;
10726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dw_minus = params.rp_dw_minus;
10736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dw_min = params.rp_dw_min;
10746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dw_max = params.rp_dw_max;
10756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    l_count = layer_sizes->cols;
10776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ivcount = layer_sizes->data.i[0];
10786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ovcount = layer_sizes->data.i[l_count-1];
10796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // allocate buffers
10816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < l_count; i++ )
10826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        total += layer_sizes->data.i[i];
10836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( dw = cvCreateMat( wbuf->rows, wbuf->cols, wbuf->type ));
10856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvSet( dw, cvScalarAll(params.rp_dw0) );
10866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( dEdw = cvCreateMat( wbuf->rows, wbuf->cols, wbuf->type ));
10876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( dEdw );
10886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( prev_dEdw_sign = cvCreateMat( wbuf->rows, wbuf->cols, CV_8SC1 ));
10896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( prev_dEdw_sign );
10906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    inv_count = 1./count;
10926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dcount0 = max_buf_sz/(2*total);
10936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dcount0 = MAX( dcount0, 1 );
10946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dcount0 = MIN( dcount0, count );
10956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_sz = dcount0*(total + max_count)*2;
10966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( buf = cvCreateMat( 1, buf_sz, CV_64F ));
10986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( x = (double**)cvAlloc( total*2*sizeof(x[0]) ));
11006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    df = x + total;
11016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_ptr = buf->data.db;
11026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < l_count; i++ )
11046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
11056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        x[i] = buf_ptr;
11066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        df[i] = x[i] + layer_sizes->data.i[i]*dcount0;
11076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        buf_ptr += (df[i] - x[i])*2;
11086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
11096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // run rprop loop
11116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    /*
11126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        y_i(t) = w_i(t)*x_{i-1}(t)
11136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        x_i(t) = f(y_i(t))
11146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        E = sum_over_all_samples(1/2*||u - x_N||^2)
11156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        grad_N = (x_N - u)*f'(y_i)
11166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      MIN(dw_i{jk}(t)*dw_plus, dw_max), if dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) > 0
11186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dw_i{jk}(t) = MAX(dw_i{jk}(t)*dw_minus, dw_min), if dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) < 0
11196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      dw_i{jk}(t-1) else
11206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if (dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) < 0)
11226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn           dE/dw_i{jk}(t)<-0
11236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
11246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn           w_i{jk}(t+1) = w_i{jk}(t) + dw_i{jk}(t)
11256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        grad_{i-1}(t) = w_i^t(t)*grad_i(t)
11266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    */
11276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( iter = 0; iter < max_iter; iter++ )
11286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
11296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n1, n2, si, j, k;
11306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double* w;
11316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat _w, _dEdw, hdr1, hdr2, ghdr1, ghdr2, _df;
11326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat *x1, *x2, *grad1, *grad2, *temp;
11336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double E = 0;
11346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // first, iterate through all the samples and compute dEdw
11366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( si = 0; si < count; si += dcount )
11376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
11386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dcount = MIN( count - si, dcount0 );
11396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            w = weights[0];
11406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            grad1 = &ghdr1; grad2 = &ghdr2;
11416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            x1 = &hdr1; x2 = &hdr2;
11426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // grab and preprocess input data
11446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( x0.type == CV_32F )
11456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < dcount; i++ )
11466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
11476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const float* x0data = x0.data.fl[si+i];
11486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double* xdata = x[0]+i*ivcount;
11496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j = 0; j < ivcount; j++ )
11506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        xdata[j] = x0data[j]*w[j*2] + w[j*2+1];
11516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
11526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
11536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < dcount; i++ )
11546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
11556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const double* x0data = x0.data.db[si+i];
11566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double* xdata = x[0]+i*ivcount;
11576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j = 0; j < ivcount; j++ )
11586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        xdata[j] = x0data[j]*w[j*2] + w[j*2+1];
11596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
11606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( x1, dcount, ivcount, CV_64F, x[0] );
11626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // forward pass, compute y[i]=w*x[i-1], x[i]=f(y[i]), df[i]=f'(y[i])
11646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 1; i < l_count; i++ )
11656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
11666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvInitMatHeader( x2, dcount, layer_sizes->data.i[i], CV_64F, x[i] );
11676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvInitMatHeader( &_w, x1->cols, x2->cols, CV_64F, weights[i] );
11686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvGEMM( x1, &_w, 1, 0, 0, x2 );
11696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                _df = *x2;
11706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                _df.data.db = df[i];
11716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                calc_activ_func_deriv( x2, &_df, _w.data.db + _w.rows*_w.cols );
11726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_SWAP( x1, x2, temp );
11736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
11746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvInitMatHeader( grad1, dcount, ovcount, CV_64F, buf_ptr );
11766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            w = weights[l_count+1];
11776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            grad2->data.db = buf_ptr + max_count*dcount;
11786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // calculate error
11806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( u.type == CV_32F )
11816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < dcount; i++ )
11826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
11836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const float* udata = u.data.fl[si+i];
11846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const double* xdata = x[l_count-1] + i*ovcount;
11856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double* gdata = grad1->data.db + i*ovcount;
11866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double sweight = sw ? sw[si+i] : inv_count, E1 = 0;
11876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j = 0; j < ovcount; j++ )
11896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
11906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        double t = udata[j]*w[j*2] + w[j*2+1] - xdata[j];
11916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        gdata[j] = t*sweight;
11926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        E1 += t*t;
11936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
11946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    E += sweight*E1;
11956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
11966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
11976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < dcount; i++ )
11986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
11996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const double* udata = u.data.db[si+i];
12006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const double* xdata = x[l_count-1] + i*ovcount;
12016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double* gdata = grad1->data.db + i*ovcount;
12026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double sweight = sw ? sw[si+i] : inv_count, E1 = 0;
12036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j = 0; j < ovcount; j++ )
12056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
12066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        double t = udata[j]*w[j*2] + w[j*2+1] - xdata[j];
12076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        gdata[j] = t*sweight;
12086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        E1 += t*t;
12096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
12106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    E += sweight*E1;
12116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
12126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // backward pass, update dEdw
12146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = l_count-1; i > 0; i-- )
12156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
12166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                n1 = layer_sizes->data.i[i-1]; n2 = layer_sizes->data.i[i];
12176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvInitMatHeader( &_df, dcount, n2, CV_64F, df[i] );
12186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvMul( grad1, &_df, grad1 );
12196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvInitMatHeader( &_dEdw, n1, n2, CV_64F, dEdw->data.db+(weights[i]-weights[0]) );
12206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvInitMatHeader( x1, dcount, n1, CV_64F, x[i-1] );
12216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvGEMM( x1, grad1, 1, &_dEdw, 1, &_dEdw, CV_GEMM_A_T );
12226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                // update bias part of dEdw
12236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( k = 0; k < dcount; k++ )
12246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
12256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double* dst = _dEdw.data.db + n1*n2;
12266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const double* src = grad1->data.db + k*n2;
12276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j = 0; j < n2; j++ )
12286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dst[j] += src[j];
12296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
12306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvInitMatHeader( &_w, n1, n2, CV_64F, weights[i] );
12316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvInitMatHeader( grad2, dcount, n1, CV_64F, grad2->data.db );
12326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( i > 1 )
12346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cvGEMM( grad1, &_w, 1, 0, 0, grad2, CV_GEMM_B_T );
12356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_SWAP( grad1, grad2, temp );
12366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
12376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
12386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // now update weights
12406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 1; i < l_count; i++ )
12416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
12426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            n1 = layer_sizes->data.i[i-1]; n2 = layer_sizes->data.i[i];
12436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k <= n1; k++ )
12446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
12456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double* wk = weights[i]+k*n2;
12466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                size_t delta = wk - weights[0];
12476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double* dwk = dw->data.db + delta;
12486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double* dEdwk = dEdw->data.db + delta;
12496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                char* prevEk = (char*)(prev_dEdw_sign->data.ptr + delta);
12506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < n2; j++ )
12526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
12536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double Eval = dEdwk[j];
12546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double dval = dwk[j];
12556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double wval = wk[j];
12566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    int s = CV_SIGN(Eval);
12576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    int ss = prevEk[j]*s;
12586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( ss > 0 )
12596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
12606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dval *= dw_plus;
12616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dval = MIN( dval, dw_max );
12626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dwk[j] = dval;
12636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        wk[j] = wval + dval*s;
12646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
12656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    else if( ss < 0 )
12666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
12676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dval *= dw_minus;
12686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dval = MAX( dval, dw_min );
12696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        prevEk[j] = 0;
12706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dwk[j] = dval;
12716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        wk[j] = wval + dval*s;
12726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
12736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    else
12746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
12756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        prevEk[j] = (char)s;
12766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        wk[j] = wval + dval*s;
12776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
12786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    dEdwk[j] = 0.;
12796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
12806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
12816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
12826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( fabs(prev_E - E) < epsilon )
12846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
12856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        prev_E = E;
12866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        E = 0;
12876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
12886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
12906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dw );
12926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dEdw );
12936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &prev_dEdw_sign );
12946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &buf );
12956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &x );
12966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return iter;
12986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::write_params( CvFileStorage* fs )
13026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
13036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //CV_FUNCNAME( "CvANN_MLP::write_params" );
13046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
13066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const char* activ_func_name = activ_func == IDENTITY ? "IDENTITY" :
13086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            activ_func == SIGMOID_SYM ? "SIGMOID_SYM" :
13096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            activ_func == GAUSSIAN ? "GAUSSIAN" : 0;
13106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( activ_func_name )
13126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteString( fs, "activation_function", activ_func_name );
13136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
13146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteInt( fs, "activation_function", activ_func );
13156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( activ_func != IDENTITY )
13176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "f_param1", f_param1 );
13196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "f_param2", f_param2 );
13206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "min_val", min_val );
13236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "max_val", max_val );
13246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "min_val1", min_val1 );
13256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "max_val1", max_val1 );
13266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
13286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.train_method == CvANN_MLP_TrainParams::BACKPROP )
13296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteString( fs, "train_method", "BACKPROP" );
13316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "dw_scale", params.bp_dw_scale );
13326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "moment_scale", params.bp_moment_scale );
13336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( params.train_method == CvANN_MLP_TrainParams::RPROP )
13356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteString( fs, "train_method", "RPROP" );
13376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "dw0", params.rp_dw0 );
13386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "dw_plus", params.rp_dw_plus );
13396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "dw_minus", params.rp_dw_minus );
13406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "dw_min", params.rp_dw_min );
13416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "dw_max", params.rp_dw_max );
13426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
13456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.term_crit.type & CV_TERMCRIT_EPS )
13466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
13476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.term_crit.type & CV_TERMCRIT_ITER )
13486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteInt( fs, "iterations", params.term_crit.max_iter );
13496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
13506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
13526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
13546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
13556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::write( CvFileStorage* fs, const char* name )
13586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
13596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::write" );
13606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
13626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, l_count = layer_sizes->cols;
13646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !layer_sizes )
13666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsError, "The network has not been initialized" );
13676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_ANN_MLP );
13696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWrite( fs, "layer_sizes", layer_sizes );
13716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    write_params( fs );
13736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "input_scale", CV_NODE_SEQ + CV_NODE_FLOW );
13756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteRawData( fs, weights[0], layer_sizes->data.i[0]*2, "d" );
13766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
13776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "output_scale", CV_NODE_SEQ + CV_NODE_FLOW );
13796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteRawData( fs, weights[l_count], layer_sizes->data.i[l_count-1]*2, "d" );
13806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
13816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "inv_output_scale", CV_NODE_SEQ + CV_NODE_FLOW );
13836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteRawData( fs, weights[l_count+1], layer_sizes->data.i[l_count-1]*2, "d" );
13846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
13856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "weights", CV_NODE_SEQ );
13876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 1; i < l_count; i++ )
13886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
13906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteRawData( fs, weights[i], (layer_sizes->data.i[i-1]+1)*layer_sizes->data.i[i], "d" );
13916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvEndWriteStruct( fs );
13926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
13956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
13976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
13986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::read_params( CvFileStorage* fs, CvFileNode* node )
14016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
14026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //CV_FUNCNAME( "CvANN_MLP::read_params" );
14036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
14056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const char* activ_func_name = cvReadStringByName( fs, node, "activation_function", 0 );
14076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvFileNode* tparams_node;
14086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( activ_func_name )
14106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        activ_func = strcmp( activ_func_name, "SIGMOID_SYM" ) == 0 ? SIGMOID_SYM :
14116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                     strcmp( activ_func_name, "IDENTITY" ) == 0 ? IDENTITY :
14126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                     strcmp( activ_func_name, "GAUSSIAN" ) == 0 ? GAUSSIAN : 0;
14136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
14146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        activ_func = cvReadIntByName( fs, node, "activation_function" );
14156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    f_param1 = cvReadRealByName( fs, node, "f_param1", 0 );
14176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    f_param2 = cvReadRealByName( fs, node, "f_param2", 0 );
14186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    set_activ_func( activ_func, f_param1, f_param2 );
14206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    min_val = cvReadRealByName( fs, node, "min_val", 0. );
14226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_val = cvReadRealByName( fs, node, "max_val", 1. );
14236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    min_val1 = cvReadRealByName( fs, node, "min_val1", 0. );
14246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_val1 = cvReadRealByName( fs, node, "max_val1", 1. );
14256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
14276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params = CvANN_MLP_TrainParams();
14286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( tparams_node )
14306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
14316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const char* tmethod_name = cvReadStringByName( fs, tparams_node, "train_method", "" );
14326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvFileNode* tcrit_node;
14336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( strcmp( tmethod_name, "BACKPROP" ) == 0 )
14356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
14366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.train_method = CvANN_MLP_TrainParams::BACKPROP;
14376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.bp_dw_scale = cvReadRealByName( fs, tparams_node, "dw_scale", 0 );
14386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.bp_moment_scale = cvReadRealByName( fs, tparams_node, "moment_scale", 0 );
14396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
14406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( strcmp( tmethod_name, "RPROP" ) == 0 )
14416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
14426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.train_method = CvANN_MLP_TrainParams::RPROP;
14436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.rp_dw0 = cvReadRealByName( fs, tparams_node, "dw0", 0 );
14446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.rp_dw_plus = cvReadRealByName( fs, tparams_node, "dw_plus", 0 );
14456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.rp_dw_minus = cvReadRealByName( fs, tparams_node, "dw_minus", 0 );
14466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.rp_dw_min = cvReadRealByName( fs, tparams_node, "dw_min", 0 );
14476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.rp_dw_max = cvReadRealByName( fs, tparams_node, "dw_max", 0 );
14486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
14496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        tcrit_node = cvGetFileNodeByName( fs, tparams_node, "term_criteria" );
14516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( tcrit_node )
14526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
14536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.term_crit.epsilon = cvReadRealByName( fs, tcrit_node, "epsilon", -1 );
14546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.term_crit.max_iter = cvReadIntByName( fs, tcrit_node, "iterations", -1 );
14556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.term_crit.type = (params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
14566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                   (params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
14576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
14586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
14596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
14616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
14626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::read( CvFileStorage* fs, CvFileNode* node )
14656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
14666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* _layer_sizes = 0;
14676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvANN_MLP::read" );
14696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
14716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvFileNode* w;
14736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvSeqReader reader;
14746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, l_count;
14756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    _layer_sizes = (CvMat*)cvReadByName( fs, node, "layer_sizes" );
14776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( create( _layer_sizes, SIGMOID_SYM, 0, 0 ));
14786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    l_count = layer_sizes->cols;
14796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( read_params( fs, node ));
14816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    w = cvGetFileNodeByName( fs, node, "input_scale" );
14836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !w || CV_NODE_TYPE(w->tag) != CV_NODE_SEQ ||
14846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        w->data.seq->total != layer_sizes->data.i[0]*2 )
14856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "input_scale tag is not found or is invalid" );
14866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvReadRawData( fs, w, weights[0], "d" ));
14886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    w = cvGetFileNodeByName( fs, node, "output_scale" );
14906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !w || CV_NODE_TYPE(w->tag) != CV_NODE_SEQ ||
14916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        w->data.seq->total != layer_sizes->data.i[l_count-1]*2 )
14926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "output_scale tag is not found or is invalid" );
14936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvReadRawData( fs, w, weights[l_count], "d" ));
14956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    w = cvGetFileNodeByName( fs, node, "inv_output_scale" );
14976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !w || CV_NODE_TYPE(w->tag) != CV_NODE_SEQ ||
14986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        w->data.seq->total != layer_sizes->data.i[l_count-1]*2 )
14996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "inv_output_scale tag is not found or is invalid" );
15006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvReadRawData( fs, w, weights[l_count+1], "d" ));
15026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    w = cvGetFileNodeByName( fs, node, "weights" );
15046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !w || CV_NODE_TYPE(w->tag) != CV_NODE_SEQ ||
15056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        w->data.seq->total != l_count - 1 )
15066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "weights tag is not found or is invalid" );
15076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartReadSeq( w->data.seq, &reader );
15096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 1; i < l_count; i++ )
15116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        w = (CvFileNode*)reader.ptr;
15136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cvReadRawData( fs, w, weights[i], "d" ));
15146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
15156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
15186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
15196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* End of file. */
1521