/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of Intel Corporation may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
31793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// In no event shall the Intel Corporation or contributors be liable for any direct, 32793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// indirect, incidental, special, exemplary, or consequential damages 33793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// (including, but not limited to, procurement of substitute goods or services; 34793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// loss of use, data, or profits; or business interruption) however caused 35793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// and on any theory of liability, whether in contract, strict liability, 36793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// or tort (including negligence or otherwise) arising in any way out of 37793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// the use of this software, even if advised of the possibility of such damage. 38793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 39793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//M*/ 40793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 41793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler#include "precomp.hpp" 42793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 43793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslernamespace cv { namespace ml { 44793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 45793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerstruct AnnParams 46793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{ 47793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler AnnParams() 48793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 49793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler termCrit = TermCriteria( TermCriteria::COUNT + TermCriteria::EPS, 1000, 0.01 ); 50793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler trainMethod = ANN_MLP::RPROP; 51793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bpDWScale = bpMomentScale = 0.1; 52793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler rpDW0 = 0.1; rpDWPlus = 
1.2; rpDWMinus = 0.5; 53793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler rpDWMin = FLT_EPSILON; rpDWMax = 50.; 54793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 55793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 56793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler TermCriteria termCrit; 57793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int trainMethod; 58793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 59793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double bpDWScale; 60793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double bpMomentScale; 61793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 62793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double rpDW0; 63793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double rpDWPlus; 64793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double rpDWMinus; 65793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double rpDWMin; 66793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double rpDWMax; 67793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}; 68793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 69793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslertemplate <typename T> 70793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerinline T inBounds(T val, T min_val, T max_val) 71793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{ 72793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return std::min(std::max(val, min_val), max_val); 73793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler} 74793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 75793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerclass ANN_MLPImpl : public ANN_MLP 76793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{ 77793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerpublic: 78793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler ANN_MLPImpl() 79793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 80793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler clear(); 81793ee12c6df9cad3806238d32528c49a3ff9331dNoah 
Presler setActivationFunction( SIGMOID_SYM, 0, 0 ); 82793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler setLayerSizes(Mat()); 83793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler setTrainMethod(ANN_MLP::RPROP, 0.1, FLT_EPSILON); 84793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 85793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 86793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler virtual ~ANN_MLPImpl() {} 87793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 88793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.termCrit) 89793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(double, BackpropWeightScale, params.bpDWScale) 90793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(double, BackpropMomentumScale, params.bpMomentScale) 91793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(double, RpropDW0, params.rpDW0) 92793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(double, RpropDWPlus, params.rpDWPlus) 93793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(double, RpropDWMinus, params.rpDWMinus) 94793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(double, RpropDWMin, params.rpDWMin) 95793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(double, RpropDWMax, params.rpDWMax) 96793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 97793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void clear() 98793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 99793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler min_val = max_val = min_val1 = max_val1 = 0.; 100793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler rng = RNG((uint64)-1); 101793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler weights.clear(); 102793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler trained = false; 103793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_buf_sz = 1 << 12; 
104793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 105793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 106793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int layer_count() const { return (int)layer_sizes.size(); } 107793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 108793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void setTrainMethod(int method, double param1, double param2) 109793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 110793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if (method != ANN_MLP::RPROP && method != ANN_MLP::BACKPROP) 111793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler method = ANN_MLP::RPROP; 112793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.trainMethod = method; 113793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if(method == ANN_MLP::RPROP ) 114793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 115793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( param1 < FLT_EPSILON ) 116793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler param1 = 1.; 117793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.rpDW0 = param1; 118793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.rpDWMin = std::max( param2, 0. ); 119793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 120793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else if(method == ANN_MLP::BACKPROP ) 121793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 122793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( param1 <= 0 ) 123793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler param1 = 0.1; 124793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.bpDWScale = inBounds<double>(param1, 1e-3, 1.); 125793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( param2 < 0 ) 126793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler param2 = 0.1; 127793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.bpMomentScale = std::min( param2, 1. 
); 128793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 129793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 130793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 131793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int getTrainMethod() const 132793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 133793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return params.trainMethod; 134793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 135793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 136793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void setActivationFunction(int _activ_func, double _f_param1, double _f_param2 ) 137793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 138793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( _activ_func < 0 || _activ_func > GAUSSIAN ) 139793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error( CV_StsOutOfRange, "Unknown activation function" ); 140793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 141793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler activ_func = _activ_func; 142793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 143793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler switch( activ_func ) 144793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 145793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler case SIGMOID_SYM: 146793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_val = 0.95; min_val = -max_val; 147793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_val1 = 0.98; min_val1 = -max_val1; 148793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( fabs(_f_param1) < FLT_EPSILON ) 149793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler _f_param1 = 2./3; 150793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( fabs(_f_param2) < FLT_EPSILON ) 151793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler _f_param2 = 1.7159; 152793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 153793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler case GAUSSIAN: 
154793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_val = 1.; min_val = 0.05; 155793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_val1 = 1.; min_val1 = 0.02; 156793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( fabs(_f_param1) < FLT_EPSILON ) 157793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler _f_param1 = 1.; 158793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( fabs(_f_param2) < FLT_EPSILON ) 159793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler _f_param2 = 1.; 160793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 161793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler default: 162793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler min_val = max_val = min_val1 = max_val1 = 0.; 163793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler _f_param1 = 1.; 164793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler _f_param2 = 0.; 165793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 166793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 167793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler f_param1 = _f_param1; 168793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler f_param2 = _f_param2; 169793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 170793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 171793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 172793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void init_weights() 173793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 174793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, j, k, l_count = layer_count(); 175793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 176793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 1; i < l_count; i++ ) 177793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 178793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n1 = layer_sizes[i-1]; 179793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n2 = layer_sizes[i]; 180793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double val = 0, G = n2 > 2 ? 
0.7*pow((double)n1,1./(n2-1)) : 1.; 181793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* w = weights[i].ptr<double>(); 182793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 183793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // initialize weights using Nguyen-Widrow algorithm 184793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < n2; j++ ) 185793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 186793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double s = 0; 187793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( k = 0; k <= n1; k++ ) 188793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 189793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler val = rng.uniform(0., 1.)*2-1.; 190793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w[k*n2 + j] = val; 191793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler s += fabs(val); 192793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 193793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 194793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( i < l_count - 1 ) 195793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 196793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler s = 1./(s - fabs(val)); 197793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( k = 0; k <= n1; k++ ) 198793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w[k*n2 + j] *= s; 199793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w[n1*n2 + j] *= G*(-1+j*2./n2); 200793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 201793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 202793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 203793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 204793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 205793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat getLayerSizes() const 206793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 207793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return Mat_<int>(layer_sizes, true); 
208793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 209793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 210793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void setLayerSizes( InputArray _layer_sizes ) 211793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 212793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler clear(); 213793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 214793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler _layer_sizes.copyTo(layer_sizes); 215793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int l_count = layer_count(); 216793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 217793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler weights.resize(l_count + 2); 218793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_lsize = 0; 219793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 220793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( l_count > 0 ) 221793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 222793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int i = 0; i < l_count; i++ ) 223793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 224793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n = layer_sizes[i]; 225793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( n < 1 + (0 < i && i < l_count-1)) 226793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error( CV_StsOutOfRange, 227793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "there should be at least one input and one output " 228793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "and every hidden layer must have more than 1 neuron" ); 229793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_lsize = std::max( max_lsize, n ); 230793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( i > 0 ) 231793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler weights[i].create(layer_sizes[i-1]+1, n, CV_64F); 232793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 233793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 
234793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int ninputs = layer_sizes.front(); 235793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int noutputs = layer_sizes.back(); 236793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler weights[0].create(1, ninputs*2, CV_64F); 237793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler weights[l_count].create(1, noutputs*2, CV_64F); 238793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler weights[l_count+1].create(1, noutputs*2, CV_64F); 239793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 240793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 241793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 242793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler float predict( InputArray _inputs, OutputArray _outputs, int ) const 243793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 244793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !trained ) 245793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error( CV_StsError, "The network has not been trained or loaded" ); 246793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 247793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat inputs = _inputs.getMat(); 248793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int type = inputs.type(), l_count = layer_count(); 249793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n = inputs.rows, dn0 = n; 250793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 251793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Assert( (type == CV_32F || type == CV_64F) && inputs.cols == layer_sizes[0] ); 252793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int noutputs = layer_sizes[l_count-1]; 253793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat outputs; 254793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 255793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int min_buf_sz = 2*max_lsize; 256793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int buf_sz = n*min_buf_sz; 
257793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 258793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( buf_sz > max_buf_sz ) 259793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 260793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dn0 = max_buf_sz/min_buf_sz; 261793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dn0 = std::max( dn0, 1 ); 262793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler buf_sz = dn0*min_buf_sz; 263793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 264793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 265793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler cv::AutoBuffer<double> _buf(buf_sz+noutputs); 266793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* buf = _buf; 267793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 268793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !_outputs.needed() ) 269793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 270793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Assert( n == 1 ); 271793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler outputs = Mat(n, noutputs, type, buf + buf_sz); 272793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 273793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 274793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 275793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler _outputs.create(n, noutputs, type); 276793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler outputs = _outputs.getMat(); 277793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 278793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 279793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int dn = 0; 280793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int i = 0; i < n; i += dn ) 281793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 282793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dn = std::min( dn0, n - i ); 283793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 284793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat 
layer_in = inputs.rowRange(i, i + dn); 285793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat layer_out( dn, layer_in.cols, CV_64F, buf); 286793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 287793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale_input( layer_in, layer_out ); 288793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler layer_in = layer_out; 289793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 290793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int j = 1; j < l_count; j++ ) 291793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 292793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* data = buf + ((j&1) ? max_lsize*dn0 : 0); 293793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int cols = layer_sizes[j]; 294793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 295793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler layer_out = Mat(dn, cols, CV_64F, data); 296793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat w = weights[j].rowRange(0, layer_in.cols); 297793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler gemm(layer_in, w, 1, noArray(), 0, layer_out); 298793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler calc_activ_func( layer_out, weights[j] ); 299793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 300793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler layer_in = layer_out; 301793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 302793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 303793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler layer_out = outputs.rowRange(i, i + dn); 304793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale_output( layer_in, layer_out ); 305793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 306793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 307793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( n == 1 ) 308793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 309793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int maxIdx[] = {0, 0}; 
310793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler minMaxIdx(outputs, 0, 0, 0, maxIdx); 311793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return (float)(maxIdx[0] + maxIdx[1]); 312793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 313793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 314793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return 0.f; 315793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 316793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 317793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void scale_input( const Mat& _src, Mat& _dst ) const 318793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 319793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int cols = _src.cols; 320793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* w = weights[0].ptr<double>(); 321793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 322793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( _src.type() == CV_32F ) 323793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 324793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int i = 0; i < _src.rows; i++ ) 325793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 326793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const float* src = _src.ptr<float>(i); 327793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* dst = _dst.ptr<double>(i); 328793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int j = 0; j < cols; j++ ) 329793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dst[j] = src[j]*w[j*2] + w[j*2+1]; 330793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 331793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 332793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 333793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 334793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int i = 0; i < _src.rows; i++ ) 335793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 336793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const float* 
src = _src.ptr<float>(i); 337793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* dst = _dst.ptr<double>(i); 338793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int j = 0; j < cols; j++ ) 339793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dst[j] = src[j]*w[j*2] + w[j*2+1]; 340793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 341793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 342793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 343793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 344793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void scale_output( const Mat& _src, Mat& _dst ) const 345793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 346793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int cols = _src.cols; 347793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* w = weights[layer_count()].ptr<double>(); 348793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 349793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( _dst.type() == CV_32F ) 350793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 351793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int i = 0; i < _src.rows; i++ ) 352793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 353793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* src = _src.ptr<double>(i); 354793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler float* dst = _dst.ptr<float>(i); 355793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int j = 0; j < cols; j++ ) 356793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dst[j] = (float)(src[j]*w[j*2] + w[j*2+1]); 357793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 358793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 359793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 360793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 361793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int i = 0; i < _src.rows; i++ ) 362793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 
363793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* src = _src.ptr<double>(i); 364793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* dst = _dst.ptr<double>(i); 365793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int j = 0; j < cols; j++ ) 366793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dst[j] = src[j]*w[j*2] + w[j*2+1]; 367793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 368793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 369793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 370793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 371793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void calc_activ_func( Mat& sums, const Mat& w ) const 372793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 373793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* bias = w.ptr<double>(w.rows-1); 374793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, j, n = sums.rows, cols = sums.cols; 375793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double scale = 0, scale2 = f_param2; 376793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 377793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler switch( activ_func ) 378793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 379793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler case IDENTITY: 380793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale = 1.; 381793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 382793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler case SIGMOID_SYM: 383793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale = -f_param1; 384793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 385793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler case GAUSSIAN: 386793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale = -f_param1*f_param1; 387793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 388793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler default: 389793ee12c6df9cad3806238d32528c49a3ff9331dNoah 
Presler ; 390793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 391793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 392793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Assert( sums.isContinuous() ); 393793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 394793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( activ_func != GAUSSIAN ) 395793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 396793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 397793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 398793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* data = sums.ptr<double>(i); 399793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < cols; j++ ) 400793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler data[j] = (data[j] + bias[j])*scale; 401793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 402793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 403793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( activ_func == IDENTITY ) 404793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return; 405793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 406793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 407793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 408793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 409793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 410793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* data = sums.ptr<double>(i); 411793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < cols; j++ ) 412793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 413793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double t = data[j] + bias[j]; 414793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler data[j] = t*t*scale; 415793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 416793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 417793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 
418793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 419793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler exp( sums, sums ); 420793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 421793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( sums.isContinuous() ) 422793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 423793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler cols *= n; 424793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler n = 1; 425793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 426793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 427793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler switch( activ_func ) 428793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 429793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler case SIGMOID_SYM: 430793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 431793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 432793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* data = sums.ptr<double>(i); 433793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < cols; j++ ) 434793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 435793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double t = scale2*(1. - data[j])/(1. 
+ data[j]); 436793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler data[j] = t; 437793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 438793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 439793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 440793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 441793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler case GAUSSIAN: 442793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 443793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 444793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* data = sums.ptr<double>(i); 445793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < cols; j++ ) 446793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler data[j] = scale2*data[j]; 447793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 448793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 449793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 450793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler default: 451793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler ; 452793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 453793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 454793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 455793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void calc_activ_func_deriv( Mat& _xf, Mat& _df, const Mat& w ) const 456793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 457793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* bias = w.ptr<double>(w.rows-1); 458793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, j, n = _xf.rows, cols = _xf.cols; 459793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 460793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( activ_func == IDENTITY ) 461793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 462793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 463793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 
464793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* xf = _xf.ptr<double>(i); 465793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* df = _df.ptr<double>(i); 466793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 467793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < cols; j++ ) 468793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 469793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler xf[j] += bias[j]; 470793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler df[j] = 1; 471793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 472793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 473793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 474793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else if( activ_func == GAUSSIAN ) 475793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 476793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double scale = -f_param1*f_param1; 477793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double scale2 = scale*f_param2; 478793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 479793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 480793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* xf = _xf.ptr<double>(i); 481793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* df = _df.ptr<double>(i); 482793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 483793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < cols; j++ ) 484793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 485793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double t = xf[j] + bias[j]; 486793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler df[j] = t*2*scale2; 487793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler xf[j] = t*t*scale; 488793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 489793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 490793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler exp( _xf, _xf ); 
491793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 492793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 493793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 494793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* xf = _xf.ptr<double>(i); 495793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* df = _df.ptr<double>(i); 496793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 497793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < cols; j++ ) 498793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler df[j] *= xf[j]; 499793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 500793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 501793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 502793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 503793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double scale = f_param1; 504793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double scale2 = f_param2; 505793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 506793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 507793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 508793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* xf = _xf.ptr<double>(i); 509793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* df = _df.ptr<double>(i); 510793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 511793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < cols; j++ ) 512793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 513793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler xf[j] = (xf[j] + bias[j])*scale; 514793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler df[j] = -fabs(xf[j]); 515793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 516793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 517793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 518793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler exp( _df, _df ); 
519793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 520793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // ((1+exp(-ax))^-1)'=a*((1+exp(-ax))^-2)*exp(-ax); 521793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // ((1-exp(-ax))/(1+exp(-ax)))'=(a*exp(-ax)*(1+exp(-ax)) + a*exp(-ax)*(1-exp(-ax)))/(1+exp(-ax))^2= 522793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // 2*a*exp(-ax)/(1+exp(-ax))^2 523793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale *= 2*f_param2; 524793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 525793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 526793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* xf = _xf.ptr<double>(i); 527793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* df = _df.ptr<double>(i); 528793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 529793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < cols; j++ ) 530793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 531793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int s0 = xf[j] > 0 ? 1 : -1; 532793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double t0 = 1./(1. + df[j]); 533793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double t1 = scale*df[j]*t0*t0; 534793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler t0 *= scale2*(1. 
- df[j])*s0; 535793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler df[j] = t1; 536793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler xf[j] = t0; 537793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 538793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 539793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 540793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 541793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 542793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void calc_input_scale( const Mat& inputs, int flags ) 543793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 544793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool reset_weights = (flags & UPDATE_WEIGHTS) == 0; 545793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool no_scale = (flags & NO_INPUT_SCALE) != 0; 546793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* scale = weights[0].ptr<double>(); 547793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int count = inputs.rows; 548793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 549793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( reset_weights ) 550793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 551793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, j, vcount = layer_sizes[0]; 552793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int type = inputs.type(); 553793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double a = no_scale ? 1. 
: 0.; 554793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 555793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < vcount; j++ ) 556793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[2*j] = a, scale[j*2+1] = 0.; 557793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 558793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( no_scale ) 559793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return; 560793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 561793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < count; i++ ) 562793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 563793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const uchar* p = inputs.ptr(i); 564793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const float* f = (const float*)p; 565793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* d = (const double*)p; 566793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < vcount; j++ ) 567793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 568793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double t = type == CV_32F ? (double)f[j] : d[j]; 569793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[j*2] += t; 570793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[j*2+1] += t*t; 571793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 572793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 573793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 574793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < vcount; j++ ) 575793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 576793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double s = scale[j*2], s2 = scale[j*2+1]; 577793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double m = s/count, sigma2 = s2/count - m*m; 578793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[j*2] = sigma2 < DBL_EPSILON ? 
1 : 1./sqrt(sigma2); 579793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[j*2+1] = -m*scale[j*2]; 580793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 581793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 582793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 583793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 584793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void calc_output_scale( const Mat& outputs, int flags ) 585793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 586793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, j, vcount = layer_sizes.back(); 587793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int type = outputs.type(); 588793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double m = min_val, M = max_val, m1 = min_val1, M1 = max_val1; 589793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool reset_weights = (flags & UPDATE_WEIGHTS) == 0; 590793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool no_scale = (flags & NO_OUTPUT_SCALE) != 0; 591793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int l_count = layer_count(); 592793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* scale = weights[l_count].ptr<double>(); 593793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* inv_scale = weights[l_count+1].ptr<double>(); 594793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int count = outputs.rows; 595793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 596793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( reset_weights ) 597793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 598793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double a0 = no_scale ? 1 : DBL_MAX, b0 = no_scale ? 
0 : -DBL_MAX; 599793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 600793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < vcount; j++ ) 601793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 602793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[2*j] = inv_scale[2*j] = a0; 603793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[j*2+1] = inv_scale[2*j+1] = b0; 604793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 605793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 606793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( no_scale ) 607793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return; 608793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 609793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 610793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < count; i++ ) 611793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 612793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const uchar* p = outputs.ptr(i); 613793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const float* f = (const float*)p; 614793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* d = (const double*)p; 615793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 616793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < vcount; j++ ) 617793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 618793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double t = type == CV_32F ? 
(double)f[j] : d[j]; 619793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 620793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( reset_weights ) 621793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 622793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double mj = scale[j*2], Mj = scale[j*2+1]; 623793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( mj > t ) mj = t; 624793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( Mj < t ) Mj = t; 625793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 626793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[j*2] = mj; 627793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[j*2+1] = Mj; 628793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 629793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else if( !no_scale ) 630793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 631793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler t = t*inv_scale[j*2] + inv_scale[2*j+1]; 632793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( t < m1 || t > M1 ) 633793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error( CV_StsOutOfRange, 634793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "Some of new output training vector components run exceed the original range too much" ); 635793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 636793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 637793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 638793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 639793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( reset_weights ) 640793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < vcount; j++ ) 641793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 642793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // map mj..Mj to m..M 643793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double mj = scale[j*2], Mj = scale[j*2+1]; 644793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double a, b; 
645793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double delta = Mj - mj; 646793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( delta < DBL_EPSILON ) 647793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler a = 1, b = (M + m - Mj - mj)*0.5; 648793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 649793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler a = (M - m)/delta, b = m - mj*a; 650793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler inv_scale[j*2] = a; inv_scale[j*2+1] = b; 651793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler a = 1./a; b = -b*a; 652793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler scale[j*2] = a; scale[j*2+1] = b; 653793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 654793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 655793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 656793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void prepare_to_train( const Mat& inputs, const Mat& outputs, 657793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat& sample_weights, int flags ) 658793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 659793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( layer_sizes.empty() ) 660793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error( CV_StsError, 661793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "The network has not been created. 
Use method create or the appropriate constructor" ); 662793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 663793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( (inputs.type() != CV_32F && inputs.type() != CV_64F) || 664793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler inputs.cols != layer_sizes[0] ) 665793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error( CV_StsBadArg, 666793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "input training data should be a floating-point matrix with " 667793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "the number of rows equal to the number of training samples and " 668793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "the number of columns equal to the size of 0-th (input) layer" ); 669793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 670793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( (outputs.type() != CV_32F && outputs.type() != CV_64F) || 671793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler outputs.cols != layer_sizes.back() ) 672793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error( CV_StsBadArg, 673793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "output training data should be a floating-point matrix with " 674793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "the number of rows equal to the number of training samples and " 675793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler "the number of columns equal to the size of last (output) layer" ); 676793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 677793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( inputs.rows != outputs.rows ) 678793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error( CV_StsUnmatchedSizes, "The numbers of input and output samples do not match" ); 679793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 680793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat temp; 681793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double s = sum(sample_weights)[0]; 
682793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler sample_weights.convertTo(temp, CV_64F, 1./s); 683793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler sample_weights = temp; 684793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 685793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler calc_input_scale( inputs, flags ); 686793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler calc_output_scale( outputs, flags ); 687793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 688793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 689793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool train( const Ptr<TrainData>& trainData, int flags ) 690793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 691793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const int MAX_ITER = 1000; 692793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double DEFAULT_EPSILON = FLT_EPSILON; 693793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 694793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // initialize training data 695793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat inputs = trainData->getTrainSamples(); 696793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat outputs = trainData->getTrainResponses(); 697793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat sw = trainData->getTrainSampleWeights(); 698793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler prepare_to_train( inputs, outputs, sw, flags ); 699793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 700793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // ... 
and link weights 701793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !(flags & UPDATE_WEIGHTS) ) 702793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler init_weights(); 703793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 704793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler TermCriteria termcrit; 705793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler termcrit.type = TermCriteria::COUNT + TermCriteria::EPS; 706793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler termcrit.maxCount = std::max((params.termCrit.type & CV_TERMCRIT_ITER ? params.termCrit.maxCount : MAX_ITER), 1); 707793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler termcrit.epsilon = std::max((params.termCrit.type & CV_TERMCRIT_EPS ? params.termCrit.epsilon : DEFAULT_EPSILON), DBL_EPSILON); 708793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 709793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int iter = params.trainMethod == ANN_MLP::BACKPROP ? 710793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler train_backprop( inputs, outputs, sw, termcrit ) : 711793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler train_rprop( inputs, outputs, sw, termcrit ); 712793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 713793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler trained = iter > 0; 714793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return trained; 715793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 716793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 717793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int train_backprop( const Mat& inputs, const Mat& outputs, const Mat& _sw, TermCriteria termCrit ) 718793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 719793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, j, k; 720793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double prev_E = DBL_MAX*0.5, E = 0; 721793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int itype = inputs.type(), otype = outputs.type(); 722793ee12c6df9cad3806238d32528c49a3ff9331dNoah 
Presler 723793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int count = inputs.rows; 724793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 725793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int iter = -1, max_iter = termCrit.maxCount*count; 726793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double epsilon = termCrit.epsilon*count; 727793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 728793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int l_count = layer_count(); 729793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int ivcount = layer_sizes[0]; 730793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int ovcount = layer_sizes.back(); 731793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 732793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // allocate buffers 733793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<vector<double> > x(l_count); 734793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<vector<double> > df(l_count); 735793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<Mat> dw(l_count); 736793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 737793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < l_count; i++ ) 738793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 739793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n = layer_sizes[i]; 740793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler x[i].resize(n+1); 741793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler df[i].resize(n); 742793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dw[i] = Mat::zeros(weights[i].size(), CV_64F); 743793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 744793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 745793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _idx_m(1, count, CV_32S); 746793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int* _idx = _idx_m.ptr<int>(); 747793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < count; i++ ) 
748793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler _idx[i] = i; 749793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 750793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler AutoBuffer<double> _buf(max_lsize*2); 751793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* buf[] = { _buf, (double*)_buf + max_lsize }; 752793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 753793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* sw = _sw.empty() ? 0 : _sw.ptr<double>(); 754793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 755793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // run back-propagation loop 756793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler /* 757793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler y_i = w_i*x_{i-1} 758793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler x_i = f(y_i) 759793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler E = 1/2*||u - x_N||^2 760793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler grad_N = (x_N - u)*f'(y_i) 761793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dw_i(t) = momentum*dw_i(t-1) + dw_scale*x_{i-1}*grad_i 762793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w_i(t+1) = w_i(t) + dw_i(t) 763793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler grad_{i-1} = w_i^t*grad_i 764793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler */ 765793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( iter = 0; iter < max_iter; iter++ ) 766793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 767793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int idx = iter % count; 768793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double sweight = sw ? count*sw[idx] : 1.; 769793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 770793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( idx == 0 ) 771793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 772793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler //printf("%d. 
E = %g\n", iter/count, E); 773793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( fabs(prev_E - E) < epsilon ) 774793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 775793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler prev_E = E; 776793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler E = 0; 777793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 778793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // shuffle indices 779793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < count; i++ ) 780793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 781793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler j = rng.uniform(0, count); 782793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler k = rng.uniform(0, count); 783793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler std::swap(_idx[j], _idx[k]); 784793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 785793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 786793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 787793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler idx = _idx[idx]; 788793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 789793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const uchar* x0data_p = inputs.ptr(idx); 790793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const float* x0data_f = (const float*)x0data_p; 791793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* x0data_d = (const double*)x0data_p; 792793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 793793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* w = weights[0].ptr<double>(); 794793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < ivcount; j++ ) 795793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler x[0][j] = (itype == CV_32F ? 
(double)x0data_f[j] : x0data_d[j])*w[j*2] + w[j*2 + 1]; 796793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 797793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat x1( 1, ivcount, CV_64F, &x[0][0] ); 798793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 799793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // forward pass, compute y[i]=w*x[i-1], x[i]=f(y[i]), df[i]=f'(y[i]) 800793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 1; i < l_count; i++ ) 801793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 802793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n = layer_sizes[i]; 803793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat x2(1, n, CV_64F, &x[i][0] ); 804793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _w = weights[i].rowRange(0, x1.cols); 805793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler gemm(x1, _w, 1, noArray(), 0, x2); 806793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _df(1, n, CV_64F, &df[i][0] ); 807793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler calc_activ_func_deriv( x2, _df, weights[i] ); 808793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler x1 = x2; 809793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 810793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 811793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat grad1( 1, ovcount, CV_64F, buf[l_count&1] ); 812793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w = weights[l_count+1].ptr<double>(); 813793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 814793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // calculate error 815793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const uchar* udata_p = outputs.ptr(idx); 816793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const float* udata_f = (const float*)udata_p; 817793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* udata_d = (const double*)udata_p; 818793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 
819793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* gdata = grad1.ptr<double>(); 820793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( k = 0; k < ovcount; k++ ) 821793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 822793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double t = (otype == CV_32F ? (double)udata_f[k] : udata_d[k])*w[k*2] + w[k*2+1] - x[l_count-1][k]; 823793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler gdata[k] = t*sweight; 824793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler E += t*t; 825793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 826793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler E *= sweight; 827793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 828793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // backward pass, update weights 829793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = l_count-1; i > 0; i-- ) 830793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 831793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n1 = layer_sizes[i-1], n2 = layer_sizes[i]; 832793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _df(1, n2, CV_64F, &df[i][0]); 833793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler multiply( grad1, _df, grad1 ); 834793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _x(n1+1, 1, CV_64F, &x[i-1][0]); 835793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler x[i-1][n1] = 1.; 836793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler gemm( _x, grad1, params.bpDWScale, dw[i], params.bpMomentScale, dw[i] ); 837793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler add( weights[i], dw[i], weights[i] ); 838793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( i > 1 ) 839793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 840793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat grad2(1, n1, CV_64F, buf[i&1]); 841793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _w = weights[i].rowRange(0, n1); 842793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 
gemm( grad1, _w, 1, noArray(), 0, grad2, GEMM_2_T ); 843793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler grad1 = grad2; 844793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 845793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 846793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 847793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 848793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler iter /= count; 849793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return iter; 850793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 851793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 852793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler struct RPropLoop : public ParallelLoopBody 853793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 854793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler RPropLoop(ANN_MLPImpl* _ann, 855793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const Mat& _inputs, const Mat& _outputs, const Mat& _sw, 856793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int _dcount0, vector<Mat>& _dEdw, double* _E) 857793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 858793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler ann = _ann; 859793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler inputs = _inputs; 860793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler outputs = _outputs; 861793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler sw = _sw.ptr<double>(); 862793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dcount0 = _dcount0; 863793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dEdw = &_dEdw; 864793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler pE = _E; 865793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 866793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 867793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler ANN_MLPImpl* ann; 868793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<Mat>* dEdw; 869793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat inputs, outputs; 
870793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* sw; 871793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int dcount0; 872793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* pE; 873793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 874793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void operator()( const Range& range ) const 875793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 876793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double inv_count = 1./inputs.rows; 877793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int ivcount = ann->layer_sizes.front(); 878793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int ovcount = ann->layer_sizes.back(); 879793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int itype = inputs.type(), otype = outputs.type(); 880793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int count = inputs.rows; 881793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, j, k, l_count = ann->layer_count(); 882793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<vector<double> > x(l_count); 883793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<vector<double> > df(l_count); 884793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<double> _buf(ann->max_lsize*dcount0*2); 885793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* buf[] = { &_buf[0], &_buf[ann->max_lsize*dcount0] }; 886793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double E = 0; 887793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 888793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < l_count; i++ ) 889793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 890793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler x[i].resize(ann->layer_sizes[i]*dcount0); 891793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler df[i].resize(ann->layer_sizes[i]*dcount0); 892793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 893793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 
894793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int si = range.start; si < range.end; si++ ) 895793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 896793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i0 = si*dcount0, i1 = std::min((si + 1)*dcount0, count); 897793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int dcount = i1 - i0; 898793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* w = ann->weights[0].ptr<double>(); 899793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 900793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // grab and preprocess input data 901793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < dcount; i++ ) 902793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 903793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const uchar* x0data_p = inputs.ptr(i0 + i); 904793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const float* x0data_f = (const float*)x0data_p; 905793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* x0data_d = (const double*)x0data_p; 906793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 907793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* xdata = &x[0][i*ivcount]; 908793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < ivcount; j++ ) 909793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler xdata[j] = (itype == CV_32F ? 
(double)x0data_f[j] : x0data_d[j])*w[j*2] + w[j*2+1]; 910793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 911793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat x1(dcount, ivcount, CV_64F, &x[0][0]); 912793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 913793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // forward pass, compute y[i]=w*x[i-1], x[i]=f(y[i]), df[i]=f'(y[i]) 914793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 1; i < l_count; i++ ) 915793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 916793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat x2( dcount, ann->layer_sizes[i], CV_64F, &x[i][0] ); 917793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _w = ann->weights[i].rowRange(0, x1.cols); 918793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler gemm( x1, _w, 1, noArray(), 0, x2 ); 919793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _df( x2.size(), CV_64F, &df[i][0] ); 920793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler ann->calc_activ_func_deriv( x2, _df, ann->weights[i] ); 921793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler x1 = x2; 922793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 923793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 924793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat grad1(dcount, ovcount, CV_64F, buf[l_count & 1]); 925793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 926793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w = ann->weights[l_count+1].ptr<double>(); 927793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 928793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // calculate error 929793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < dcount; i++ ) 930793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 931793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const uchar* udata_p = outputs.ptr(i0+i); 932793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const float* udata_f = (const float*)udata_p; 
933793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* udata_d = (const double*)udata_p; 934793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 935793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* xdata = &x[l_count-1][i*ovcount]; 936793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* gdata = grad1.ptr<double>(i); 937793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double sweight = sw ? sw[si+i] : inv_count, E1 = 0; 938793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 939793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < ovcount; j++ ) 940793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 941793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double t = (otype == CV_32F ? (double)udata_f[j] : udata_d[j])*w[j*2] + w[j*2+1] - xdata[j]; 942793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler gdata[j] = t*sweight; 943793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler E1 += t*t; 944793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 945793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler E += sweight*E1; 946793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 947793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 948793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = l_count-1; i > 0; i-- ) 949793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 950793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n1 = ann->layer_sizes[i-1], n2 = ann->layer_sizes[i]; 951793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _df(dcount, n2, CV_64F, &df[i][0]); 952793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler multiply(grad1, _df, grad1); 953793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 954793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 955793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler AutoLock lock(ann->mtx); 956793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _dEdw = dEdw->at(i).rowRange(0, n1); 957793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler x1 
= Mat(dcount, n1, CV_64F, &x[i-1][0]); 958793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler gemm(x1, grad1, 1, _dEdw, 1, _dEdw, GEMM_1_T); 959793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 960793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // update bias part of dEdw 961793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* dst = dEdw->at(i).ptr<double>(n1); 962793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( k = 0; k < dcount; k++ ) 963793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 964793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const double* src = grad1.ptr<double>(k); 965793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( j = 0; j < n2; j++ ) 966793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dst[j] += src[j]; 967793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 968793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 969793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 970793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat grad2( dcount, n1, CV_64F, buf[i&1] ); 971793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( i > 1 ) 972793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 973793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat _w = ann->weights[i].rowRange(0, n1); 974793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler gemm(grad1, _w, 1, noArray(), 0, grad2, GEMM_2_T); 975793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 976793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler grad1 = grad2; 977793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 978793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 979793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 980793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler AutoLock lock(ann->mtx); 981793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler *pE += E; 982793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 983793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 984793ee12c6df9cad3806238d32528c49a3ff9331dNoah 
Presler }; 985793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 986793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int train_rprop( const Mat& inputs, const Mat& outputs, const Mat& _sw, TermCriteria termCrit ) 987793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 988793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const int max_buf_size = 1 << 16; 989793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, iter = -1, count = inputs.rows; 990793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 991793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double prev_E = DBL_MAX*0.5; 992793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 993793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int max_iter = termCrit.maxCount; 994793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double epsilon = termCrit.epsilon; 995793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double dw_plus = params.rpDWPlus; 996793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double dw_minus = params.rpDWMinus; 997793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double dw_min = params.rpDWMin; 998793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double dw_max = params.rpDWMax; 999793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1000793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int l_count = layer_count(); 1001793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1002793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // allocate buffers 1003793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<Mat> dw(l_count), dEdw(l_count), prev_dEdw_sign(l_count); 1004793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1005793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int total = 0; 1006793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < l_count; i++ ) 1007793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1008793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler total += layer_sizes[i]; 
1009793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dw[i].create(weights[i].size(), CV_64F); 1010793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dw[i].setTo(Scalar::all(params.rpDW0)); 1011793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler prev_dEdw_sign[i] = Mat::zeros(weights[i].size(), CV_8S); 1012793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dEdw[i] = Mat::zeros(weights[i].size(), CV_64F); 1013793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1014793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1015793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int dcount0 = max_buf_size/(2*total); 1016793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dcount0 = std::max( dcount0, 1 ); 1017793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dcount0 = std::min( dcount0, count ); 1018793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int chunk_count = (count + dcount0 - 1)/dcount0; 1019793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1020793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // run rprop loop 1021793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler /* 1022793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler y_i(t) = w_i(t)*x_{i-1}(t) 1023793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler x_i(t) = f(y_i(t)) 1024793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler E = sum_over_all_samples(1/2*||u - x_N||^2) 1025793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler grad_N = (x_N - u)*f'(y_i) 1026793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1027793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler std::min(dw_i{jk}(t)*dw_plus, dw_max), if dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) > 0 1028793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dw_i{jk}(t) = std::max(dw_i{jk}(t)*dw_minus, dw_min), if dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) < 0 1029793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dw_i{jk}(t-1) else 1030793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1031793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if 
(dE/dw_i{jk}(t)*dE/dw_i{jk}(t-1) < 0) 1032793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dE/dw_i{jk}(t)<-0 1033793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 1034793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w_i{jk}(t+1) = w_i{jk}(t) + dw_i{jk}(t) 1035793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler grad_{i-1}(t) = w_i^t(t)*grad_i(t) 1036793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler */ 1037793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( iter = 0; iter < max_iter; iter++ ) 1038793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1039793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double E = 0; 1040793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1041793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < l_count; i++ ) 1042793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dEdw[i].setTo(Scalar::all(0)); 1043793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1044793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // first, iterate through all the samples and compute dEdw 1045793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler RPropLoop invoker(this, inputs, outputs, _sw, dcount0, dEdw, &E); 1046793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler parallel_for_(Range(0, chunk_count), invoker); 1047793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler //invoker(Range(0, chunk_count)); 1048793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1049793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // now update weights 1050793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 1; i < l_count; i++ ) 1051793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1052793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n1 = layer_sizes[i-1], n2 = layer_sizes[i]; 1053793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int k = 0; k <= n1; k++ ) 1054793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1055793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Assert(weights[i].size() == 
Size(n2, n1+1)); 1056793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* wk = weights[i].ptr<double>(k); 1057793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* dwk = dw[i].ptr<double>(k); 1058793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double* dEdwk = dEdw[i].ptr<double>(k); 1059793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler schar* prevEk = prev_dEdw_sign[i].ptr<schar>(k); 1060793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1061793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int j = 0; j < n2; j++ ) 1062793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1063793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double Eval = dEdwk[j]; 1064793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double dval = dwk[j]; 1065793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double wval = wk[j]; 1066793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int s = CV_SIGN(Eval); 1067793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int ss = prevEk[j]*s; 1068793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( ss > 0 ) 1069793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1070793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dval *= dw_plus; 1071793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dval = std::min( dval, dw_max ); 1072793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dwk[j] = dval; 1073793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler wk[j] = wval + dval*s; 1074793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1075793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else if( ss < 0 ) 1076793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1077793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dval *= dw_minus; 1078793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dval = std::max( dval, dw_min ); 1079793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler prevEk[j] = 0; 1080793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dwk[j] = dval; 
1081793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler wk[j] = wval + dval*s; 1082793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1083793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 1084793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1085793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler prevEk[j] = (schar)s; 1086793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler wk[j] = wval + dval*s; 1087793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1088793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler dEdwk[j] = 0.; 1089793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1090793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1091793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1092793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1093793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler //printf("%d. E = %g\n", iter, E); 1094793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( fabs(prev_E - E) < epsilon ) 1095793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 1096793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler prev_E = E; 1097793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1098793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1099793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return iter; 1100793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1101793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1102793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void write_params( FileStorage& fs ) const 1103793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1104793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const char* activ_func_name = activ_func == IDENTITY ? "IDENTITY" : 1105793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler activ_func == SIGMOID_SYM ? "SIGMOID_SYM" : 1106793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler activ_func == GAUSSIAN ? 
"GAUSSIAN" : 0; 1107793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1108793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( activ_func_name ) 1109793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "activation_function" << activ_func_name; 1110793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 1111793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "activation_function_id" << activ_func; 1112793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1113793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( activ_func != IDENTITY ) 1114793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1115793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "f_param1" << f_param1; 1116793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "f_param2" << f_param2; 1117793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1118793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1119793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "min_val" << min_val << "max_val" << max_val << "min_val1" << min_val1 << "max_val1" << max_val1; 1120793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1121793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "training_params" << "{"; 1122793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( params.trainMethod == ANN_MLP::BACKPROP ) 1123793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1124793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "train_method" << "BACKPROP"; 1125793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "dw_scale" << params.bpDWScale; 1126793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "moment_scale" << params.bpMomentScale; 1127793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1128793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else if( params.trainMethod == ANN_MLP::RPROP ) 1129793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1130793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "train_method" << "RPROP"; 
1131793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "dw0" << params.rpDW0; 1132793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "dw_plus" << params.rpDWPlus; 1133793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "dw_minus" << params.rpDWMinus; 1134793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "dw_min" << params.rpDWMin; 1135793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "dw_max" << params.rpDWMax; 1136793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1137793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 1138793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error(CV_StsError, "Unknown training method"); 1139793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1140793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "term_criteria" << "{"; 1141793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( params.termCrit.type & TermCriteria::EPS ) 1142793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "epsilon" << params.termCrit.epsilon; 1143793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( params.termCrit.type & TermCriteria::COUNT ) 1144793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "iterations" << params.termCrit.maxCount; 1145793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "}" << "}"; 1146793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1147793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1148793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void write( FileStorage& fs ) const 1149793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1150793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( layer_sizes.empty() ) 1151793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return; 1152793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, l_count = layer_count(); 1153793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1154793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "layer_sizes" << layer_sizes; 
1155793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1156793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler write_params( fs ); 1157793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1158793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler size_t esz = weights[0].elemSize(); 1159793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1160793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "input_scale" << "["; 1161793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs.writeRaw("d", weights[0].ptr(), weights[0].total()*esz); 1162793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1163793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "]" << "output_scale" << "["; 1164793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs.writeRaw("d", weights[l_count].ptr(), weights[l_count].total()*esz); 1165793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1166793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "]" << "inv_output_scale" << "["; 1167793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs.writeRaw("d", weights[l_count+1].ptr(), weights[l_count+1].total()*esz); 1168793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1169793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "]" << "weights" << "["; 1170793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 1; i < l_count; i++ ) 1171793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1172793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "["; 1173793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs.writeRaw("d", weights[i].ptr(), weights[i].total()*esz); 1174793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "]"; 1175793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1176793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "]"; 1177793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1178793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1179793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void read_params( const FileNode& fn ) 
1180793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1181793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler String activ_func_name = (String)fn["activation_function"]; 1182793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !activ_func_name.empty() ) 1183793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1184793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler activ_func = activ_func_name == "SIGMOID_SYM" ? SIGMOID_SYM : 1185793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler activ_func_name == "IDENTITY" ? IDENTITY : 1186793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler activ_func_name == "GAUSSIAN" ? GAUSSIAN : -1; 1187793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Assert( activ_func >= 0 ); 1188793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1189793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 1190793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler activ_func = (int)fn["activation_function_id"]; 1191793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1192793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler f_param1 = (double)fn["f_param1"]; 1193793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler f_param2 = (double)fn["f_param2"]; 1194793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1195793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler setActivationFunction( activ_func, f_param1, f_param2 ); 1196793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1197793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler min_val = (double)fn["min_val"]; 1198793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_val = (double)fn["max_val"]; 1199793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler min_val1 = (double)fn["min_val1"]; 1200793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_val1 = (double)fn["max_val1"]; 1201793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1202793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNode tpn = fn["training_params"]; 1203793ee12c6df9cad3806238d32528c49a3ff9331dNoah 
Presler params = AnnParams(); 1204793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1205793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !tpn.empty() ) 1206793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1207793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler String tmethod_name = (String)tpn["train_method"]; 1208793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1209793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( tmethod_name == "BACKPROP" ) 1210793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1211793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.trainMethod = ANN_MLP::BACKPROP; 1212793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.bpDWScale = (double)tpn["dw_scale"]; 1213793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.bpMomentScale = (double)tpn["moment_scale"]; 1214793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1215793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else if( tmethod_name == "RPROP" ) 1216793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1217793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.trainMethod = ANN_MLP::RPROP; 1218793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.rpDW0 = (double)tpn["dw0"]; 1219793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.rpDWPlus = (double)tpn["dw_plus"]; 1220793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.rpDWMinus = (double)tpn["dw_minus"]; 1221793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.rpDWMin = (double)tpn["dw_min"]; 1222793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.rpDWMax = (double)tpn["dw_max"]; 1223793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1224793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 1225793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error(CV_StsParseError, "Unknown training method (should be BACKPROP or RPROP)"); 1226793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 
1227793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNode tcn = tpn["term_criteria"]; 1228793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !tcn.empty() ) 1229793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1230793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNode tcn_e = tcn["epsilon"]; 1231793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNode tcn_i = tcn["iterations"]; 1232793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.termCrit.type = 0; 1233793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !tcn_e.empty() ) 1234793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1235793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.termCrit.type |= TermCriteria::EPS; 1236793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.termCrit.epsilon = (double)tcn_e; 1237793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1238793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !tcn_i.empty() ) 1239793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1240793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.termCrit.type |= TermCriteria::COUNT; 1241793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.termCrit.maxCount = (int)tcn_i; 1242793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1243793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1244793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1245793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1246793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1247793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void read( const FileNode& fn ) 1248793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1249793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler clear(); 1250793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1251793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<int> _layer_sizes; 1252793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler readVectorOrMat(fn["layer_sizes"], _layer_sizes); 
1253793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler setLayerSizes( _layer_sizes ); 1254793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1255793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, l_count = layer_count(); 1256793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler read_params(fn); 1257793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1258793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler size_t esz = weights[0].elemSize(); 1259793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1260793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNode w = fn["input_scale"]; 1261793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w.readRaw("d", weights[0].ptr(), weights[0].total()*esz); 1262793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1263793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w = fn["output_scale"]; 1264793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w.readRaw("d", weights[l_count].ptr(), weights[l_count].total()*esz); 1265793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1266793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w = fn["inv_output_scale"]; 1267793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler w.readRaw("d", weights[l_count+1].ptr(), weights[l_count+1].total()*esz); 1268793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1269793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNodeIterator w_it = fn["weights"].begin(); 1270793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1271793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 1; i < l_count; i++, ++w_it ) 1272793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler (*w_it).readRaw("d", weights[i].ptr(), weights[i].total()*esz); 1273793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler trained = true; 1274793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1275793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1276793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat getWeights(int layerIdx) const 
1277793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1278793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Assert( 0 <= layerIdx && layerIdx < (int)weights.size() ); 1279793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return weights[layerIdx]; 1280793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1281793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1282793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool isTrained() const 1283793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1284793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return trained; 1285793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1286793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1287793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool isClassifier() const 1288793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1289793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return false; 1290793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1291793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1292793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int getVarCount() const 1293793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1294793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return layer_sizes.empty() ? 
0 : layer_sizes[0]; 1295793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1296793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1297793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler String getDefaultName() const 1298793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 1299793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return "opencv_ml_ann_mlp"; 1300793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 1301793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1302793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<int> layer_sizes; 1303793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<Mat> weights; 1304793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double f_param1, f_param2; 1305793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double min_val, max_val, min_val1, max_val1; 1306793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int activ_func; 1307793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int max_lsize, max_buf_sz; 1308793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler AnnParams params; 1309793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler RNG rng; 1310793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mutex mtx; 1311793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool trained; 1312793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}; 1313793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1314793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1315793ee12c6df9cad3806238d32528c49a3ff9331dNoah PreslerPtr<ANN_MLP> ANN_MLP::create() 1316793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{ 1317793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return makePtr<ANN_MLPImpl>(); 1318793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler} 1319793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1320793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}} 1321793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 1322793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler/* End of file. */ 1323