16acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*M/////////////////////////////////////////////////////////////////////////////////////// 26acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 36acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 46acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 56acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// By downloading, copying, installing or using the software you agree to this license. 66acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// If you do not agree to this license, do not download, install, 76acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// copy or use the software. 86acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 96acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Intel License Agreement 116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Copyright (C) 2000, Intel Corporation, all rights reserved. 136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Third party copyrights are property of their respective owners. 146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Redistribution and use in source and binary forms, with or without modification, 166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are permitted provided that the following conditions are met: 176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// * Redistribution's of source code must retain the above copyright notice, 196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// this list of conditions and the following disclaimer. 206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// * Redistribution's in binary form must reproduce the above copyright notice, 226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// this list of conditions and the following disclaimer in the documentation 236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and/or other materials provided with the distribution. 246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// * The name of Intel Corporation may not be used to endorse or promote products 266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// derived from this software without specific prior written permission. 276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// This software is provided by the copyright holders and contributors "as is" and 296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// any express or implied warranties, including, but not limited to, the implied 306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// warranties of merchantability and fitness for a particular purpose are disclaimed. 316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct, 326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages 336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// (including, but not limited to, procurement of substitute goods or services; 346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused 356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability, 366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort (including negligence or otherwise) arising in any way out of 376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even if advised of the possibility of such damage. 386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/ 406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h" 426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP_TrainParams::CvANN_MLP_TrainParams() 446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn term_crit = cvTermCriteria( CV_TERMCRIT_ITER + CV_TERMCRIT_EPS, 1000, 0.01 ); 466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn train_method = RPROP; 476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bp_dw_scale = bp_moment_scale = 0.1; 486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rp_dw0 = 0.1; rp_dw_plus = 1.2; rp_dw_minus = 0.5; 496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rp_dw_min = FLT_EPSILON; rp_dw_max = 50.; 506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP_TrainParams::CvANN_MLP_TrainParams( CvTermCriteria _term_crit, 546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int _train_method, 556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double _param1, double _param2 ) 566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn term_crit = _term_crit; 586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn train_method = _train_method; 596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bp_dw_scale = bp_moment_scale = 0.1; 606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rp_dw0 = 1.; rp_dw_plus = 1.2; rp_dw_minus = 0.5; 616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rp_dw_min = FLT_EPSILON; rp_dw_max = 50.; 626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( train_method == RPROP ) 646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rp_dw0 = _param1; 666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( rp_dw0 < FLT_EPSILON ) 676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rp_dw0 = 1.; 686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rp_dw_min = _param2; 696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rp_dw_min = MAX( rp_dw_min, 0 ); 706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( train_method == BACKPROP ) 726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bp_dw_scale = _param1; 746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( bp_dw_scale <= 0 ) 756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bp_dw_scale = 0.1; 766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bp_dw_scale = MAX( bp_dw_scale, 1e-3 ); 776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bp_dw_scale = MIN( bp_dw_scale, 1 ); 786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bp_moment_scale = _param2; 796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( bp_moment_scale < 0 ) 806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bp_moment_scale = 0.1; 816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bp_moment_scale = MIN( bp_moment_scale, 1 ); 826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn train_method = RPROP; 856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP_TrainParams::~CvANN_MLP_TrainParams() 896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP::CvANN_MLP() 946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer_sizes = wbuf = 0; 966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_val = max_val = min_val1 = max_val1 = 0.; 976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights = 0; 986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rng = cvRNG(-1); 996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default_model_name = "my_nn"; 1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP::CvANN_MLP( const CvMat* _layer_sizes, 1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int _activ_func, 1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double _f_param1, double _f_param2 ) 1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer_sizes = wbuf = 0; 1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_val = max_val = min_val1 = max_val1 = 0.; 1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights = 0; 1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rng = cvRNG(-1); 1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default_model_name = "my_nn"; 1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn create( _layer_sizes, _activ_func, _f_param1, _f_param2 ); 1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvANN_MLP::~CvANN_MLP() 1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::clear() 1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer_sizes ); 1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &wbuf ); 1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &weights ); 1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn activ_func = SIGMOID_SYM; 1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn f_param1 = f_param2 = 1; 1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_buf_sz = 1 << 12; 1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::set_activ_func( int _activ_func, double _f_param1, double _f_param2 ) 1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::set_activ_func" ); 1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _activ_func < 0 || _activ_func > GAUSSIAN ) 1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "Unknown activation function" ); 1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn activ_func = _activ_func; 1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn switch( activ_func ) 1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn case SIGMOID_SYM: 1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val = 0.95; min_val = -max_val; 1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val1 = 0.98; min_val1 = -max_val1; 1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fabs(_f_param1) < FLT_EPSILON ) 1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _f_param1 = 2./3; 1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fabs(_f_param2) < FLT_EPSILON ) 1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _f_param2 = 1.7159; 1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn case GAUSSIAN: 1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val = 1.; min_val = 0.05; 1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val1 = 1.; min_val1 = 0.02; 1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fabs(_f_param1) < FLT_EPSILON ) 1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _f_param1 = 1.; 1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fabs(_f_param2) < FLT_EPSILON ) 1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _f_param2 = 1.; 1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default: 1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_val = max_val = min_val1 = max_val1 = 0.; 1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _f_param1 = 1.; 1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _f_param2 = 0.; 1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn f_param1 = _f_param1; 1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn f_param2 = _f_param2; 1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::init_weights() 1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, k; 1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 1; i < layer_sizes->cols; i++ ) 1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = layer_sizes->data.i[i-1]; 1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n2 = layer_sizes->data.i[i]; 1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = 0, G = n2 > 2 ? 0.7*pow((double)n1,1./(n2-1)) : 1.; 1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* w = weights[i]; 1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // initialize weights using Nguyen-Widrow algorithm 1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < n2; j++ ) 1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s = 0; 1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k <= n1; k++ ) 1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = cvRandReal(&rng)*2-1.; 1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w[k*n2 + j] = val; 1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn s += val; 1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( i < layer_sizes->cols - 1 ) 1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn s = 1./(s - val); 2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k <= n1; k++ ) 2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w[k*n2 + j] *= s; 2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w[n1*n2 + j] *= G*(-1+j*2./n2); 2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::create( const CvMat* _layer_sizes, int _activ_func, 2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double _f_param1, double _f_param2 ) 2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::create" ); 2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, l_step, l_count, buf_sz = 0; 2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int *l_src, *l_dst; 2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_layer_sizes) || 2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _layer_sizes->cols != 1 && _layer_sizes->rows != 1 || 2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(_layer_sizes->type) != CV_32SC1 ) 2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "The array of layer neuron counters must be an integer vector" ); 2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( set_activ_func( _activ_func, _f_param1, _f_param2 )); 2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_count = _layer_sizes->rows + _layer_sizes->cols - 1; 2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_src = _layer_sizes->data.i; 2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_step = CV_IS_MAT_CONT(_layer_sizes->type) ? 1 : 2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _layer_sizes->step / sizeof(l_src[0]); 2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( layer_sizes = cvCreateMat( 1, l_count, CV_32SC1 )); 2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_dst = layer_sizes->data.i; 2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_count = 0; 2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < l_count; i++ ) 2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = l_src[i*l_step]; 2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( n < 1 + (0 < i && i < l_count-1)) 2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, 2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "there should be at least one input and one output " 2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "and every hidden layer must have more than 1 neuron" ); 2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_dst[i] = n; 2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_count = MAX( max_count, n ); 2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( i > 0 ) 2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_sz += (l_dst[i-1]+1)*n; 2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_sz += (l_dst[0] + l_dst[l_count-1]*2)*2; 2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( wbuf = cvCreateMat( 1, buf_sz, CV_64F )); 2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( weights = (double**)cvAlloc( (l_count+1)*sizeof(weights[0]) )); 2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights[0] = wbuf->data.db; 2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights[1] = weights[0] + l_dst[0]*2; 2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 1; i < l_count; i++ ) 2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights[i+1] = weights[i] + (l_dst[i-1] + 1)*l_dst[i]; 2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights[l_count+1] = weights[l_count] + l_dst[l_count-1]*2; 2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat CvANN_MLP::predict( const CvMat* _inputs, CvMat* _outputs ) const 2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::predict" ); 2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* buf; 2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, n, dn = 0, l_count, dn0, buf_sz, min_buf_sz; 2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !layer_sizes ) 2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsError, "The network has not been initialized" ); 2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_inputs) || !CV_IS_MAT(_outputs) || 2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn !CV_ARE_TYPES_EQ(_inputs,_outputs) || 2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(_inputs->type) != CV_32FC1 && 2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(_inputs->type) != CV_64FC1 || 2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _inputs->rows != _outputs->rows ) 2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Both input and output must be floating-point matrices " 2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "of the same type and have the same number of rows" ); 2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _inputs->cols != layer_sizes->data.i[0] ) 2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "input matrix must have the same number of columns as " 2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the number of neurons in the input layer" ); 2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _outputs->cols != layer_sizes->data.i[layer_sizes->cols - 1] ) 2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "output matrix must have the same number of columns as " 2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the number of neurons in the output layer" ); 2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n = dn0 = _inputs->rows; 2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_buf_sz = 2*max_count; 2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_sz = n*min_buf_sz; 2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( buf_sz > max_buf_sz ) 2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dn0 = max_buf_sz/min_buf_sz; 3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dn0 = MAX( dn0, 1 ); 3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_sz = dn0*min_buf_sz; 3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf = (double*)cvStackAlloc( buf_sz*sizeof(buf[0]) ); 3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_count = layer_sizes->cols; 3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i += dn ) 3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat hdr[2], _w, *layer_in = &hdr[0], *layer_out = &hdr[1], *temp; 3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dn = MIN( dn0, n - i ); 3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetRows( _inputs, layer_in, i, i + dn ); 3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( layer_out, dn, layer_in->cols, CV_64F, buf ); 3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale_input( layer_in, layer_out ); 3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( layer_in, layer_out, temp ); 3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 1; j < l_count; j++ ) 3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* data = buf + (j&1 ? max_count*dn0 : 0); 3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int cols = layer_sizes->data.i[j]; 3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( layer_out, dn, cols, CV_64F, data ); 3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( &_w, layer_in->cols, layer_out->cols, CV_64F, weights[j] ); 3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGEMM( layer_in, &_w, 1, 0, 0, layer_out ); 3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn calc_activ_func( layer_out, _w.data.db + _w.rows*_w.cols ); 3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( layer_in, layer_out, temp ); 3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetRows( _outputs, layer_out, i, i + dn ); 3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale_output( layer_in, layer_out ); 3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 0.f; 3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::scale_input( const CvMat* _src, CvMat* _dst ) const 3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, cols = _src->cols; 3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* dst = _dst->data.db; 3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* w = weights[0]; 3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int step = _src->step; 3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( CV_MAT_TYPE( _src->type ) == CV_32F ) 3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* src = _src->data.fl; 3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn step /= sizeof(src[0]); 3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < _src->rows; i++, src += step, dst += cols ) 3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cols; j++ ) 3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[j] = src[j]*w[j*2] + w[j*2+1]; 3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* src = _src->data.db; 3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn step /= sizeof(src[0]); 3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < _src->rows; i++, src += step, dst += cols ) 3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cols; j++ ) 3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[j] = src[j]*w[j*2] + w[j*2+1]; 3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::scale_output( const CvMat* _src, CvMat* _dst ) const 3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, cols = _src->cols; 3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* src = _src->data.db; 3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* w = weights[layer_sizes->cols]; 3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int step = _dst->step; 3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( CV_MAT_TYPE( _dst->type ) == CV_32F ) 3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* dst = _dst->data.fl; 3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn step /= sizeof(dst[0]); 3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < _src->rows; i++, src += cols, dst += step ) 3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cols; j++ ) 3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[j] = (float)(src[j]*w[j*2] + w[j*2+1]); 3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* dst = _dst->data.db; 3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn step /= sizeof(dst[0]); 3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < _src->rows; i++, src += cols, dst += step ) 3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cols; j++ ) 3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[j] = src[j]*w[j*2] + w[j*2+1]; 3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::calc_activ_func( CvMat* sums, const double* bias ) const 3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, n = sums->rows, cols = sums->cols; 4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* data = sums->data.db; 4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double scale = 0, scale2 = f_param2; 4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn switch( activ_func ) 4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn case IDENTITY: 4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale = 1.; 4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn case SIGMOID_SYM: 4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale = -f_param1; 4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn case GAUSSIAN: 4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale = -f_param1*f_param1; 4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default: 4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ; 4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( CV_IS_MAT_CONT(sums->type) ); 4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( activ_func != GAUSSIAN ) 4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++, data += cols ) 4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cols; j++ ) 4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data[j] = (data[j] + bias[j])*scale; 4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( activ_func == IDENTITY ) 4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++, data += cols ) 4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cols; j++ ) 4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = data[j] + bias[j]; 4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data[j] = t*t*scale; 4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvExp( sums, sums ); 4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n *= cols; 4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data -= n; 4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn switch( activ_func ) 4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn case SIGMOID_SYM: 4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i <= n - 4; i += 4 ) 4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double x0 = 1.+data[i], x1 = 1.+data[i+1], x2 = 1.+data[i+2], x3 = 1.+data[i+3]; 4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double a = x0*x1, b = x2*x3, d = scale2/(a*b), t0, t1; 4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn a *= d; b *= d; 4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn t0 = (2 - x0)*b*x1; t1 = (2 - x1)*b*x0; 4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data[i] = t0; data[i+1] = t1; 4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn t0 = (2 - x2)*a*x3; t1 = (2 - x3)*a*x2; 4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data[i+2] = t0; data[i+3] = t1; 4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < n; i++ ) 4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = scale2*(1. - data[i])/(1. + data[i]); 4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data[i] = t; 4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn case GAUSSIAN: 4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data[i] = scale2*data[i]; 4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default: 4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ; 4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::calc_activ_func_deriv( CvMat* _xf, CvMat* _df, 4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* bias ) const 4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, n = _xf->rows, cols = _xf->cols; 4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* xf = _xf->data.db; 4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* df = _df->data.db; 4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double scale, scale2 = f_param2; 4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( CV_IS_MAT_CONT( _xf->type & _df->type ) ); 4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( activ_func == IDENTITY ) 4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++, xf += cols, df += cols ) 4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cols; j++ ) 4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn xf[j] += bias[j]; 4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[j] = 1; 4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( activ_func == GAUSSIAN ) 4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale = -f_param1*f_param1; 4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale2 *= scale; 4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++, xf += cols, df += cols ) 5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cols; j++ ) 5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = xf[j] + bias[j]; 5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[j] = t*2*scale2; 5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn xf[j] = t*t*scale; 5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale = -f_param1; 5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++, xf += cols, df += cols ) 5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cols; j++ ) 5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn xf[j] = (xf[j] + bias[j])*scale; 5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvExp( _xf, _xf ); 5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n *= cols; 5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn xf -= n; df -= n; 5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // ((1+exp(-ax))^-1)'=a*((1+exp(-ax))^-2)*exp(-ax); 5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // ((1-exp(-ax))/(1+exp(-ax)))'=(a*exp(-ax)*(1+exp(-ax)) + a*exp(-ax)*(1-exp(-ax)))/(1+exp(-ax))^2= 5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 2*a*exp(-ax)/(1+exp(-ax))^2 5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn switch( activ_func ) 5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn case SIGMOID_SYM: 5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale *= -2*f_param2; 5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i <= n - 4; i += 4 ) 5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double x0 = 1.+xf[i], x1 = 1.+xf[i+1], x2 = 1.+xf[i+2], x3 = 1.+xf[i+3]; 5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double a = x0*x1, b = x2*x3, d = 1./(a*b), t0, t1; 5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn a *= d; b *= d; 5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn t0 = b*x1; t1 = b*x0; 5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[i] = scale*xf[i]*t0*t0; 5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[i+1] = scale*xf[i+1]*t1*t1; 5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn t0 *= scale2*(2 - x0); t1 *= scale2*(2 - x1); 5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn xf[i] = t0; xf[i+1] = t1; 5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn t0 = a*x3; t1 = a*x2; 5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[i+2] = scale*xf[i+2]*t0*t0; 5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[i+3] = scale*xf[i+3]*t1*t1; 5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn t0 *= scale2*(2 - x2); t1 *= scale2*(2 - x3); 5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn xf[i+2] = t0; xf[i+3] = t1; 5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < n; i++ ) 5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t0 = 1./(1. + xf[i]); 5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t1 = scale*xf[i]*t0*t0; 5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn t0 *= scale2*(1. - xf[i]); 5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[i] = t1; 5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn xf[i] = t0; 5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn case GAUSSIAN: 5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[i] *= xf[i]; 5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default: 5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ; 5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::calc_input_scale( const CvVectors* vecs, int flags ) 5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool reset_weights = (flags & UPDATE_WEIGHTS) == 0; 5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool no_scale = (flags & NO_INPUT_SCALE) != 0; 5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* scale = weights[0]; 5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int count = vecs->count; 5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( reset_weights ) 5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, vcount = layer_sizes->data.i[0]; 5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int type = vecs->type; 5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double a = no_scale ? 1. : 0.; 5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < vcount; j++ ) 5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[2*j] = a, scale[j*2+1] = 0.; 5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( no_scale ) 5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* f = vecs->data.fl[i]; 5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* d = vecs->data.db[i]; 5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < vcount; j++ ) 5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = type == CV_32F ? (double)f[j] : d[j]; 5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[j*2] += t; 5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[j*2+1] += t*t; 5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < vcount; j++ ) 5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s = scale[j*2], s2 = scale[j*2+1]; 6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double m = s/count, sigma2 = s2/count - m*m; 6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[j*2] = sigma2 < DBL_EPSILON ? 1 : 1./sqrt(sigma2); 6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[j*2+1] = -m*scale[j*2]; 6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::calc_output_scale( const CvVectors* vecs, int flags ) 6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, vcount = layer_sizes->data.i[layer_sizes->cols-1]; 6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int type = vecs->type; 6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double m = min_val, M = max_val, m1 = min_val1, M1 = max_val1; 6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool reset_weights = (flags & UPDATE_WEIGHTS) == 0; 6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool no_scale = (flags & NO_OUTPUT_SCALE) != 0; 6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int l_count = layer_sizes->cols; 6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* scale = weights[l_count]; 6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* inv_scale = weights[l_count+1]; 6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int count = vecs->count; 6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::calc_output_scale" ); 6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( reset_weights ) 6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double a0 = no_scale ? 1 : DBL_MAX, b0 = no_scale ? 0 : -DBL_MAX; 6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < vcount; j++ ) 6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[2*j] = inv_scale[2*j] = a0; 6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[j*2+1] = inv_scale[2*j+1] = b0; 6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( no_scale ) 6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn EXIT; 6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* f = vecs->data.fl[i]; 6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* d = vecs->data.db[i]; 6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < vcount; j++ ) 6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = type == CV_32F ? (double)f[j] : d[j]; 6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( reset_weights ) 6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double mj = scale[j*2], Mj = scale[j*2+1]; 6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( mj > t ) mj = t; 6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( Mj < t ) Mj = t; 6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[j*2] = mj; 6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[j*2+1] = Mj; 6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn t = t*scale[j*2] + scale[2*j+1]; 6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( t < m1 || t > M1 ) 6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, 6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "Some of new output training vector components run exceed the original range too much" ); 6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( reset_weights ) 6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < vcount; j++ ) 6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // map mj..Mj to m..M 6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double mj = scale[j*2], Mj = scale[j*2+1]; 6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double a, b; 6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double delta = Mj - mj; 6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( delta < DBL_EPSILON ) 6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn a = 1, b = (M + m - Mj - mj)*0.5; 6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn a = (M - m)/delta, b = m - mj*a; 6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn inv_scale[j*2] = a; inv_scale[j*2+1] = b; 6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn a = 1./a; b = -b*a; 6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[j*2] = a; scale[j*2+1] = b; 6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvANN_MLP::prepare_to_train( const CvMat* _inputs, const CvMat* _outputs, 6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _sample_weights, const CvMat* _sample_idx, 6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags ) 6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool ok = false; 6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* sample_idx = 0; 6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvVectors ivecs, ovecs; 6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* sw = 0; 6946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int count = 0; 6956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::prepare_to_train" ); 6976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ivecs.data.ptr = ovecs.data.ptr = 0; 6996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( _ivecs && _ovecs ); 7006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 7026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* sidx = 0; 7046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, sw_type = 0, sw_count = 0; 7056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sw_step = 0; 7066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sw_sum = 0; 7076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !layer_sizes ) 7096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsError, 7106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "The network has not been created. Use method create or the appropriate constructor" ); 7116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_inputs) || CV_MAT_TYPE(_inputs->type) != CV_32FC1 && 7136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(_inputs->type) != CV_64FC1 || _inputs->cols != layer_sizes->data.i[0] ) 7146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 7156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "input training data should be a floating-point matrix with" 7166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the number of rows equal to the number of training samples and " 7176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the number of columns equal to the size of 0-th (input) layer" ); 7186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_outputs) || CV_MAT_TYPE(_outputs->type) != CV_32FC1 && 7206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(_outputs->type) != CV_64FC1 || 7216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _outputs->cols != layer_sizes->data.i[layer_sizes->cols - 1] ) 7226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 7236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "output training data should be a floating-point matrix with" 7246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the number of rows equal to the number of training samples and " 7256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the number of columns equal to the size of last (output) layer" ); 7266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _inputs->rows != _outputs->rows ) 7286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsUnmatchedSizes, "The numbers of input and output samples do not match" ); 7296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _sample_idx ) 7316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, _inputs->rows )); 7336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sidx = sample_idx->data.i; 7346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn count = sample_idx->cols + sample_idx->rows - 1; 7356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 7376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn count = _inputs->rows; 7386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _sample_weights ) 7406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_sample_weights) ) 7426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "sample_weights (if passed) must be a valid matrix" ); 7436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sw_type = CV_MAT_TYPE(_sample_weights->type); 7456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sw_count = _sample_weights->cols + _sample_weights->rows - 1; 7466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sw_type != CV_32FC1 && sw_type != CV_64FC1 || 7486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample_weights->cols != 1 && _sample_weights->rows != 1 || 7496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sw_count != count && sw_count != _inputs->rows ) 7506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 7516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "sample_weights must be 1d floating-point vector containing weights " 7526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "of all or selected training samples" ); 7536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sw_step = CV_IS_MAT_CONT(_sample_weights->type) ? 1 : 7556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample_weights->step/CV_ELEM_SIZE(sw_type); 7566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( sw = (double*)cvAlloc( count*sizeof(sw[0]) )); 7586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( ivecs.data.ptr = (uchar**)cvAlloc( count*sizeof(ivecs.data.ptr[0]) )); 7616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( ovecs.data.ptr = (uchar**)cvAlloc( count*sizeof(ovecs.data.ptr[0]) )); 7626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ivecs.type = CV_MAT_TYPE(_inputs->type); 7646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ovecs.type = CV_MAT_TYPE(_outputs->type); 7656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ivecs.count = ovecs.count = count; 7666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 7686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sidx ? sidx[i] : i; 7706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ivecs.data.ptr[i] = _inputs->data.ptr + idx*_inputs->step; 7716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ovecs.data.ptr[i] = _outputs->data.ptr + idx*_outputs->step; 7726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sw ) 7736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int si = sw_count == count ? i : idx; 7756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = sw_type == CV_32FC1 ? 7766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (double)_sample_weights->data.fl[si*sw_step] : 7776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample_weights->data.db[si*sw_step]; 7786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sw[i] = w; 7796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( w < 0 ) 7806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "some of sample weights are negative" ); 7816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sw_sum += w; 7826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // normalize weights 7866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sw ) 7876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sw_sum = sw_sum > DBL_EPSILON ? 1./sw_sum : 0; 7896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 7906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sw[i] *= sw_sum; 7916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn calc_input_scale( &ivecs, _flags ); 7946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( calc_output_scale( &ovecs, _flags )); 7956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ok = true; 7976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 7996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ok ) 8016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &ivecs.data.ptr ); 8036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &ovecs.data.ptr ); 8046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &sw ); 8056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &sample_idx ); 8086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *_ivecs = ivecs; 8096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *_ovecs = ovecs; 8106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *_sw = sw; 8116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return ok; 8136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvANN_MLP::train( const CvMat* _inputs, const CvMat* _outputs, 8176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _sample_weights, const CvMat* _sample_idx, 8186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvANN_MLP_TrainParams _params, int flags ) 8196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int MAX_ITER = 1000; 8216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double DEFAULT_EPSILON = FLT_EPSILON; 8226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* sw = 0; 8246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvVectors x0, u; 8256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int iter = -1; 8266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn x0.data.ptr = u.data.ptr = 0; 8286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::train" ); 8306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 8326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int max_iter; 8346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double epsilon; 8356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params = _params; 8376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // initialize training data 8396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( prepare_to_train( _inputs, _outputs, _sample_weights, 8406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample_idx, &x0, &u, &sw, flags )); 8416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // ... and link weights 8436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !(flags & UPDATE_WEIGHTS) ) 8446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn init_weights(); 8456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_iter = params.term_crit.type & CV_TERMCRIT_ITER ? params.term_crit.max_iter : MAX_ITER; 8476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_iter = MIN( max_iter, MAX_ITER ); 8486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_iter = MAX( max_iter, 1 ); 8496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn epsilon = params.term_crit.type & CV_TERMCRIT_EPS ? params.term_crit.epsilon : DEFAULT_EPSILON; 8516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn epsilon = MAX(epsilon, DBL_EPSILON); 8526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.term_crit.type = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS; 8546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.term_crit.max_iter = max_iter; 8556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.term_crit.epsilon = epsilon; 8566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.train_method == CvANN_MLP_TrainParams::BACKPROP ) 8586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( iter = train_backprop( x0, u, sw )); 8606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 8626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( iter = train_rprop( x0, u, sw )); 8646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 8676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &x0.data.ptr ); 8696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &u.data.ptr ); 8706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &sw ); 8716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return iter; 8736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvANN_MLP::train_backprop( CvVectors x0, CvVectors u, const double* sw ) 8776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dw = 0; 8796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* buf = 0; 8806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double **x = 0, **df = 0; 8816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* _idx = 0; 8826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int iter = -1, count = x0.count; 8836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::train_backprop" ); 8856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 8876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, k, ivcount, ovcount, l_count, total = 0, max_iter; 8896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double *buf_ptr; 8906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double prev_E = DBL_MAX*0.5, E = 0, epsilon; 8916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_iter = params.term_crit.max_iter*count; 8936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn epsilon = params.term_crit.epsilon*count; 8946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_count = layer_sizes->cols; 8966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ivcount = layer_sizes->data.i[0]; 8976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ovcount = layer_sizes->data.i[l_count-1]; 8986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // allocate buffers 9006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < l_count; i++ ) 9016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn total += layer_sizes->data.i[i] + 1; 9026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( dw = cvCreateMat( wbuf->rows, wbuf->cols, wbuf->type )); 9046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( dw ); 9056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( buf = cvCreateMat( 1, (total + max_count)*2, CV_64F )); 9066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( _idx = cvCreateMat( 1, count, CV_32SC1 )); 9076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 9086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _idx->data.i[i] = i; 9096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( x = (double**)cvAlloc( total*2*sizeof(x[0]) )); 9116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df = x + total; 9126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_ptr = buf->data.db; 9136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < l_count; j++ ) 9156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn x[j] = buf_ptr; 9176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[j] = x[j] + layer_sizes->data.i[j]; 9186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_ptr += (df[j] - x[j])*2; 9196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // run back-propagation loop 9226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn /* 9236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn y_i = w_i*x_{i-1} 9246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn x_i = f(y_i) 9256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E = 1/2*||u - x_N||^2 9266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad_N = (x_N - u)*f'(y_i) 9276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dw_i(t) = momentum*dw_i(t-1) + dw_scale*x_{i-1}*grad_i 9286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w_i(t+1) = w_i(t) + dw_i(t) 9296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad_{i-1} = w_i^t*grad_i 9306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn */ 9316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( iter = 0; iter < max_iter; iter++ ) 9326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = iter % count; 9346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* w = weights[0]; 9356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sweight = sw ? count*sw[idx] : 1.; 9366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat _w, _dw, hdr1, hdr2, ghdr1, ghdr2, _df; 9376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat *x1 = &hdr1, *x2 = &hdr2, *grad1 = &ghdr1, *grad2 = &ghdr2, *temp; 9386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( idx == 0 ) 9406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fabs(prev_E - E) < epsilon ) 9426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 9436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_E = E; 9446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E = 0; 9456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // shuffle indices 9476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 9486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int tt; 9506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn j = (unsigned)cvRandInt(&rng) % count; 9516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn k = (unsigned)cvRandInt(&rng) % count; 9526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( _idx->data.i[j], _idx->data.i[k], tt ); 9536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = _idx->data.i[idx]; 9576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( x0.type == CV_32F ) 9596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* x0data = x0.data.fl[idx]; 9616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < ivcount; j++ ) 9626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn x[0][j] = x0data[j]*w[j*2] + w[j*2 + 1]; 9636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 9656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* x0data = x0.data.db[idx]; 9676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < ivcount; j++ ) 9686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn x[0][j] = x0data[j]*w[j*2] + w[j*2 + 1]; 9696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( x1, 1, ivcount, CV_64F, x[0] ); 9726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // forward pass, compute y[i]=w*x[i-1], x[i]=f(y[i]), df[i]=f'(y[i]) 9746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 1; i < l_count; i++ ) 9756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( x2, 1, layer_sizes->data.i[i], CV_64F, x[i] ); 9776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( &_w, x1->cols, x2->cols, CV_64F, weights[i] ); 9786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGEMM( x1, &_w, 1, 0, 0, x2 ); 9796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _df = *x2; 9806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _df.data.db = df[i]; 9816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn calc_activ_func_deriv( x2, &_df, _w.data.db + _w.rows*_w.cols ); 9826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( x1, x2, temp ); 9836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( grad1, 1, ovcount, CV_64F, buf_ptr ); 9866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *grad2 = *grad1; 9876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad2->data.db = buf_ptr + max_count; 9886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = weights[l_count+1]; 9906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // calculate error 9926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( u.type == CV_32F ) 9936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* udata = u.data.fl[idx]; 9956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < ovcount; k++ ) 9966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = udata[k]*w[k*2] + w[k*2+1] - x[l_count-1][k]; 9986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad1->data.db[k] = t*sweight; 9996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E += t*t; 10006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 10036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* udata = u.data.db[idx]; 10056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < ovcount; k++ ) 10066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = udata[k]*w[k*2] + w[k*2+1] - x[l_count-1][k]; 10086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad1->data.db[k] = t*sweight; 10096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E += t*t; 10106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E *= sweight; 10136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // backward pass, update weights 10156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = l_count-1; i > 0; i-- ) 10166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = layer_sizes->data.i[i-1], n2 = layer_sizes->data.i[i]; 10186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( &_df, 1, n2, CV_64F, df[i] ); 10196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvMul( grad1, &_df, grad1 ); 10206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( &_w, n1+1, n2, CV_64F, weights[i] ); 10216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( &_dw, n1+1, n2, CV_64F, dw->data.db + (weights[i] - weights[0]) ); 10226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( x1, n1+1, 1, CV_64F, x[i-1] ); 10236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn x[i-1][n1] = 1.; 10246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGEMM( x1, grad1, params.bp_dw_scale, &_dw, params.bp_moment_scale, &_dw ); 10256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvAdd( &_w, &_dw, &_w ); 10266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( i > 1 ) 10276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad2->cols = n1; 10296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _w.rows = n1; 10306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGEMM( grad1, &_w, 1, 0, 0, grad2, CV_GEMM_B_T ); 10316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( grad1, grad2, temp ); 10336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn iter /= count; 10376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 10396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dw ); 10416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &buf ); 10426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &_idx ); 10436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &x ); 10446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return iter; 10466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 10476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvANN_MLP::train_rprop( CvVectors x0, CvVectors u, const double* sw ) 10506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 10516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int max_buf_sz = 1 << 16; 10526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dw = 0; 10536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dEdw = 0; 10546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* prev_dEdw_sign = 0; 10556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* buf = 0; 10566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double **x = 0, **df = 0; 10576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int iter = -1, count = x0.count; 10586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::train" ); 10606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 10626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, ivcount, ovcount, l_count, total = 0, max_iter, buf_sz, dcount0, dcount=0; 10646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double *buf_ptr; 10656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double prev_E = DBL_MAX*0.5, epsilon; 10666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double dw_plus, dw_minus, dw_min, dw_max; 10676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double inv_count; 10686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_iter = params.term_crit.max_iter; 10706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn epsilon = params.term_crit.epsilon; 10716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dw_plus = params.rp_dw_plus; 10726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dw_minus = params.rp_dw_minus; 10736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dw_min = params.rp_dw_min; 10746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dw_max = params.rp_dw_max; 10756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_count = layer_sizes->cols; 10776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ivcount = layer_sizes->data.i[0]; 10786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ovcount = layer_sizes->data.i[l_count-1]; 10796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // allocate buffers 10816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < l_count; i++ ) 10826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn total += layer_sizes->data.i[i]; 10836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( dw = cvCreateMat( wbuf->rows, wbuf->cols, wbuf->type )); 10856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSet( dw, cvScalarAll(params.rp_dw0) ); 10866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( dEdw = cvCreateMat( wbuf->rows, wbuf->cols, wbuf->type )); 10876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( dEdw ); 10886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( prev_dEdw_sign = cvCreateMat( wbuf->rows, wbuf->cols, CV_8SC1 )); 10896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( prev_dEdw_sign ); 10906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn inv_count = 1./count; 10926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dcount0 = max_buf_sz/(2*total); 10936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dcount0 = MAX( dcount0, 1 ); 10946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dcount0 = MIN( dcount0, count ); 10956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_sz = dcount0*(total + max_count)*2; 10966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( buf = cvCreateMat( 1, buf_sz, CV_64F )); 10986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( x = (double**)cvAlloc( total*2*sizeof(x[0]) )); 11006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df = x + total; 11016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_ptr = buf->data.db; 11026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < l_count; i++ ) 11046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn x[i] = buf_ptr; 11066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn df[i] = x[i] + layer_sizes->data.i[i]*dcount0; 11076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_ptr += (df[i] - x[i])*2; 11086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // run rprop loop 11116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn /* 11126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn y_i(t) = w_i(t)*x_{i-1}(t) 11136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn x_i(t) = f(y_i(t)) 11146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E = sum_over_all_samples(1/2*||u - x_N||^2) 11156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad_N = (x_N - u)*f'(y_i) 11166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn MIN(dw_i{jk}(t)*dw_plus, dw_max), if dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) > 0 11186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dw_i{jk}(t) = MAX(dw_i{jk}(t)*dw_minus, dw_min), if dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) < 0 11196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dw_i{jk}(t-1) else 11206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if (dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) < 0) 11226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dE/dw_i{jk}(t)<-0 11236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 11246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w_i{jk}(t+1) = w_i{jk}(t) + dw_i{jk}(t) 11256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad_{i-1}(t) = w_i^t(t)*grad_i(t) 11266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn */ 11276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( iter = 0; iter < max_iter; iter++ ) 11286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1, n2, si, j, k; 11306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* w; 11316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat _w, _dEdw, hdr1, hdr2, ghdr1, ghdr2, _df; 11326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat *x1, *x2, *grad1, *grad2, *temp; 11336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double E = 0; 11346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // first, iterate through all the samples and compute dEdw 11366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( si = 0; si < count; si += dcount ) 11376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dcount = MIN( count - si, dcount0 ); 11396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = weights[0]; 11406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad1 = &ghdr1; grad2 = &ghdr2; 11416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn x1 = &hdr1; x2 = &hdr2; 11426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // grab and preprocess input data 11446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( x0.type == CV_32F ) 11456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < dcount; i++ ) 11466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* x0data = x0.data.fl[si+i]; 11486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* xdata = x[0]+i*ivcount; 11496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < ivcount; j++ ) 11506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn xdata[j] = x0data[j]*w[j*2] + w[j*2+1]; 11516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 11536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < dcount; i++ ) 11546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* x0data = x0.data.db[si+i]; 11566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* xdata = x[0]+i*ivcount; 11576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < ivcount; j++ ) 11586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn xdata[j] = x0data[j]*w[j*2] + w[j*2+1]; 11596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( x1, dcount, ivcount, CV_64F, x[0] ); 11626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // forward pass, compute y[i]=w*x[i-1], x[i]=f(y[i]), df[i]=f'(y[i]) 11646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 1; i < l_count; i++ ) 11656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( x2, dcount, layer_sizes->data.i[i], CV_64F, x[i] ); 11676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( &_w, x1->cols, x2->cols, CV_64F, weights[i] ); 11686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGEMM( x1, &_w, 1, 0, 0, x2 ); 11696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _df = *x2; 11706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _df.data.db = df[i]; 11716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn calc_activ_func_deriv( x2, &_df, _w.data.db + _w.rows*_w.cols ); 11726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( x1, x2, temp ); 11736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( grad1, dcount, ovcount, CV_64F, buf_ptr ); 11766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = weights[l_count+1]; 11776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn grad2->data.db = buf_ptr + max_count*dcount; 11786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // calculate error 11806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( u.type == CV_32F ) 11816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < dcount; i++ ) 11826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* udata = u.data.fl[si+i]; 11846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* xdata = x[l_count-1] + i*ovcount; 11856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* gdata = grad1->data.db + i*ovcount; 11866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sweight = sw ? sw[si+i] : inv_count, E1 = 0; 11876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < ovcount; j++ ) 11896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = udata[j]*w[j*2] + w[j*2+1] - xdata[j]; 11916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn gdata[j] = t*sweight; 11926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E1 += t*t; 11936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E += sweight*E1; 11956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 11976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < dcount; i++ ) 11986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* udata = u.data.db[si+i]; 12006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* xdata = x[l_count-1] + i*ovcount; 12016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* gdata = grad1->data.db + i*ovcount; 12026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sweight = sw ? sw[si+i] : inv_count, E1 = 0; 12036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < ovcount; j++ ) 12056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = udata[j]*w[j*2] + w[j*2+1] - xdata[j]; 12076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn gdata[j] = t*sweight; 12086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E1 += t*t; 12096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E += sweight*E1; 12116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // backward pass, update dEdw 12146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = l_count-1; i > 0; i-- ) 12156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n1 = layer_sizes->data.i[i-1]; n2 = layer_sizes->data.i[i]; 12176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( &_df, dcount, n2, CV_64F, df[i] ); 12186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvMul( grad1, &_df, grad1 ); 12196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( &_dEdw, n1, n2, CV_64F, dEdw->data.db+(weights[i]-weights[0]) ); 12206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( x1, dcount, n1, CV_64F, x[i-1] ); 12216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGEMM( x1, grad1, 1, &_dEdw, 1, &_dEdw, CV_GEMM_A_T ); 12226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // update bias part of dEdw 12236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < dcount; k++ ) 12246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* dst = _dEdw.data.db + n1*n2; 12266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* src = grad1->data.db + k*n2; 12276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < n2; j++ ) 12286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[j] += src[j]; 12296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( &_w, n1, n2, CV_64F, weights[i] ); 12316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvInitMatHeader( grad2, dcount, n1, CV_64F, grad2->data.db ); 12326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( i > 1 ) 12346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGEMM( grad1, &_w, 1, 0, 0, grad2, CV_GEMM_B_T ); 12356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( grad1, grad2, temp ); 12366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // now update weights 12406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 1; i < l_count; i++ ) 12416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n1 = layer_sizes->data.i[i-1]; n2 = layer_sizes->data.i[i]; 12436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k <= n1; k++ ) 12446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* wk = weights[i]+k*n2; 12466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn size_t delta = wk - weights[0]; 12476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* dwk = dw->data.db + delta; 12486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* dEdwk = dEdw->data.db + delta; 12496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn char* prevEk = (char*)(prev_dEdw_sign->data.ptr + delta); 12506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < n2; j++ ) 12526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double Eval = dEdwk[j]; 12546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double dval = dwk[j]; 12556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double wval = wk[j]; 12566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int s = CV_SIGN(Eval); 12576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ss = prevEk[j]*s; 12586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ss > 0 ) 12596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dval *= dw_plus; 12616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dval = MIN( dval, dw_max ); 12626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dwk[j] = dval; 12636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn wk[j] = wval + dval*s; 12646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( ss < 0 ) 12666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dval *= dw_minus; 12686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dval = MAX( dval, dw_min ); 12696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prevEk[j] = 0; 12706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dwk[j] = dval; 12716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn wk[j] = wval + dval*s; 12726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 12746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prevEk[j] = (char)s; 12766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn wk[j] = wval + dval*s; 12776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dEdwk[j] = 0.; 12796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fabs(prev_E - E) < epsilon ) 12846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 12856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_E = E; 12866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn E = 0; 12876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 12906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dw ); 12926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dEdw ); 12936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &prev_dEdw_sign ); 12946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &buf ); 12956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &x ); 12966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return iter; 12986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::write_params( CvFileStorage* fs ) 13026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 13036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn //CV_FUNCNAME( "CvANN_MLP::write_params" ); 13046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 13066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* activ_func_name = activ_func == IDENTITY ? "IDENTITY" : 13086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn activ_func == SIGMOID_SYM ? "SIGMOID_SYM" : 13096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn activ_func == GAUSSIAN ? "GAUSSIAN" : 0; 13106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( activ_func_name ) 13126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteString( fs, "activation_function", activ_func_name ); 13136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 13146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "activation_function", activ_func ); 13156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( activ_func != IDENTITY ) 13176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "f_param1", f_param1 ); 13196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "f_param2", f_param2 ); 13206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "min_val", min_val ); 13236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "max_val", max_val ); 13246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "min_val1", min_val1 ); 13256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "max_val1", max_val1 ); 13266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "training_params", CV_NODE_MAP ); 13286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.train_method == CvANN_MLP_TrainParams::BACKPROP ) 13296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteString( fs, "train_method", "BACKPROP" ); 13316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "dw_scale", params.bp_dw_scale ); 13326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "moment_scale", params.bp_moment_scale ); 13336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( params.train_method == CvANN_MLP_TrainParams::RPROP ) 13356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteString( fs, "train_method", "RPROP" ); 13376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "dw0", params.rp_dw0 ); 13386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "dw_plus", params.rp_dw_plus ); 13396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "dw_minus", params.rp_dw_minus ); 13406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "dw_min", params.rp_dw_min ); 13416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "dw_max", params.rp_dw_max ); 13426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW ); 13456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.term_crit.type & CV_TERMCRIT_EPS ) 13466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "epsilon", params.term_crit.epsilon ); 13476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.term_crit.type & CV_TERMCRIT_ITER ) 13486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "iterations", params.term_crit.max_iter ); 13496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 13506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 13526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 13546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 13556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::write( CvFileStorage* fs, const char* name ) 13586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 13596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::write" ); 13606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 13626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, l_count = layer_sizes->cols; 13646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !layer_sizes ) 13666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsError, "The network has not been initialized" ); 13676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_ANN_MLP ); 13696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWrite( fs, "layer_sizes", layer_sizes ); 13716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn write_params( fs ); 13736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "input_scale", CV_NODE_SEQ + CV_NODE_FLOW ); 13756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteRawData( fs, weights[0], layer_sizes->data.i[0]*2, "d" ); 13766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 13776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "output_scale", CV_NODE_SEQ + CV_NODE_FLOW ); 13796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteRawData( fs, weights[l_count], layer_sizes->data.i[l_count-1]*2, "d" ); 13806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 13816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "inv_output_scale", CV_NODE_SEQ + CV_NODE_FLOW ); 13836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteRawData( fs, weights[l_count+1], layer_sizes->data.i[l_count-1]*2, "d" ); 13846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 13856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "weights", CV_NODE_SEQ ); 13876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 1; i < l_count; i++ ) 13886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW ); 13906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteRawData( fs, weights[i], (layer_sizes->data.i[i-1]+1)*layer_sizes->data.i[i], "d" ); 13916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 13926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 13956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 13976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 13986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::read_params( CvFileStorage* fs, CvFileNode* node ) 14016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 14026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn //CV_FUNCNAME( "CvANN_MLP::read_params" ); 14036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 14056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* activ_func_name = cvReadStringByName( fs, node, "activation_function", 0 ); 14076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* tparams_node; 14086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( activ_func_name ) 14106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn activ_func = strcmp( activ_func_name, "SIGMOID_SYM" ) == 0 ? SIGMOID_SYM : 14116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn strcmp( activ_func_name, "IDENTITY" ) == 0 ? IDENTITY : 14126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn strcmp( activ_func_name, "GAUSSIAN" ) == 0 ? GAUSSIAN : 0; 14136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 14146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn activ_func = cvReadIntByName( fs, node, "activation_function" ); 14156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn f_param1 = cvReadRealByName( fs, node, "f_param1", 0 ); 14176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn f_param2 = cvReadRealByName( fs, node, "f_param2", 0 ); 14186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn set_activ_func( activ_func, f_param1, f_param2 ); 14206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_val = cvReadRealByName( fs, node, "min_val", 0. ); 14226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val = cvReadRealByName( fs, node, "max_val", 1. ); 14236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_val1 = cvReadRealByName( fs, node, "min_val1", 0. ); 14246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val1 = cvReadRealByName( fs, node, "max_val1", 1. ); 14256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tparams_node = cvGetFileNodeByName( fs, node, "training_params" ); 14276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params = CvANN_MLP_TrainParams(); 14286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( tparams_node ) 14306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* tmethod_name = cvReadStringByName( fs, tparams_node, "train_method", "" ); 14326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* tcrit_node; 14336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( strcmp( tmethod_name, "BACKPROP" ) == 0 ) 14356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.train_method = CvANN_MLP_TrainParams::BACKPROP; 14376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.bp_dw_scale = cvReadRealByName( fs, tparams_node, "dw_scale", 0 ); 14386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.bp_moment_scale = cvReadRealByName( fs, tparams_node, "moment_scale", 0 ); 14396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( strcmp( tmethod_name, "RPROP" ) == 0 ) 14416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.train_method = CvANN_MLP_TrainParams::RPROP; 14436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.rp_dw0 = cvReadRealByName( fs, tparams_node, "dw0", 0 ); 14446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.rp_dw_plus = cvReadRealByName( fs, tparams_node, "dw_plus", 0 ); 14456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.rp_dw_minus = cvReadRealByName( fs, tparams_node, "dw_minus", 0 ); 14466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.rp_dw_min = cvReadRealByName( fs, tparams_node, "dw_min", 0 ); 14476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.rp_dw_max = cvReadRealByName( fs, tparams_node, "dw_max", 0 ); 14486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tcrit_node = cvGetFileNodeByName( fs, tparams_node, "term_criteria" ); 14516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( tcrit_node ) 14526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.term_crit.epsilon = cvReadRealByName( fs, tcrit_node, "epsilon", -1 ); 14546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.term_crit.max_iter = cvReadIntByName( fs, tcrit_node, "iterations", -1 ); 14556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.term_crit.type = (params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) + 14566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0); 14576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 14616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 14626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvANN_MLP::read( CvFileStorage* fs, CvFileNode* node ) 14656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 14666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* _layer_sizes = 0; 14676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvANN_MLP::read" ); 14696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 14716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* w; 14736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 14746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, l_count; 14756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _layer_sizes = (CvMat*)cvReadByName( fs, node, "layer_sizes" ); 14776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( create( _layer_sizes, SIGMOID_SYM, 0, 0 )); 14786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_count = layer_sizes->cols; 14796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( read_params( fs, node )); 14816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = cvGetFileNodeByName( fs, node, "input_scale" ); 14836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !w || CV_NODE_TYPE(w->tag) != CV_NODE_SEQ || 14846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w->data.seq->total != layer_sizes->data.i[0]*2 ) 14856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "input_scale tag is not found or is invalid" ); 14866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvReadRawData( fs, w, weights[0], "d" )); 14886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = cvGetFileNodeByName( fs, node, "output_scale" ); 14906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !w || CV_NODE_TYPE(w->tag) != CV_NODE_SEQ || 14916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w->data.seq->total != layer_sizes->data.i[l_count-1]*2 ) 14926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "output_scale tag is not found or is invalid" ); 14936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvReadRawData( fs, w, weights[l_count], "d" )); 14956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = cvGetFileNodeByName( fs, node, "inv_output_scale" ); 14976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !w || CV_NODE_TYPE(w->tag) != CV_NODE_SEQ || 14986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w->data.seq->total != layer_sizes->data.i[l_count-1]*2 ) 14996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "inv_output_scale tag is not found or is invalid" ); 15006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvReadRawData( fs, w, weights[l_count+1], "d" )); 15026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = cvGetFileNodeByName( fs, node, "weights" ); 15046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !w || CV_NODE_TYPE(w->tag) != CV_NODE_SEQ || 15056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w->data.seq->total != l_count - 1 ) 15066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "weights tag is not found or is invalid" ); 15076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartReadSeq( w->data.seq, &reader ); 15096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 1; i < l_count; i++ ) 15116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = (CvFileNode*)reader.ptr; 15136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvReadRawData( fs, w, weights[i], "d" )); 15146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader ); 15156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 15186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 15196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* End of file. */ 1521