16acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*M/////////////////////////////////////////////////////////////////////////////////////// 26acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 36acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 46acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 56acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// By downloading, copying, installing or using the software you agree to this license. 66acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// If you do not agree to this license, do not download, install, 76acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// copy or use the software. 86acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 96acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Intel License Agreement 116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Copyright (C) 2000, Intel Corporation, all rights reserved. 136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Third party copyrights are property of their respective owners. 146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Redistribution and use in source and binary forms, with or without modification, 166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are permitted provided that the following conditions are met: 176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// * Redistribution's of source code must retain the above copyright notice, 196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// this list of conditions and the following disclaimer. 206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// * Redistribution's in binary form must reproduce the above copyright notice, 226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// this list of conditions and the following disclaimer in the documentation 236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and/or other materials provided with the distribution. 246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// * The name of Intel Corporation may not be used to endorse or promote products 266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// derived from this software without specific prior written permission. 276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// This software is provided by the copyright holders and contributors "as is" and 296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// any express or implied warranties, including, but not limited to, the implied 306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// warranties of merchantability and fitness for a particular purpose are disclaimed. 316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct, 326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages 336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// (including, but not limited to, procurement of substitute goods or services; 346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused 356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability, 366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort (including negligence or otherwise) arising in any way out of 376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even if advised of the possibility of such damage. 386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/ 406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h" 426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#if 0 446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\ 456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn* Auxilary functions declarations * 466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/ 476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*---------------------- functions for the CNN classifier ------------------------------*/ 486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic float icvCNNModelPredict( 496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvStatModel* cnn_model, 506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* image, 516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* probs CV_DEFAULT(0) ); 526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNModelUpdate( 546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvStatModel* cnn_model, const CvMat* images, int tflag, 556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* responses, const CvStatModelParams* params, 566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* CV_DEFAULT(0), const CvMat* sample_idx CV_DEFAULT(0), 576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0)); 586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNModelRelease( CvStatModel** cnn_model ); 606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvTrainCNNetwork( CvCNNetwork* network, 626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float** images, 636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* responses, 646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* etalons, 656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int grad_estim_type, 666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int max_iter, 676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int start_iter ); 686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*------------------------- functions for the CNN network ------------------------------*/ 706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNetworkAddLayer( CvCNNetwork* network, CvCNNLayer* layer ); 716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNetworkRelease( CvCNNetwork** network ); 726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* In all layer functions we denote input by X and output by Y, where 746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn X and Y are column-vectors, so that 756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn length(X)==<n_input_planes>*<input_height>*<input_width>, 766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn length(Y)==<n_output_planes>*<output_height>*<output_width>. 776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*/ 786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*------------------------ functions for convolutional layer ---------------------------*/ 796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionRelease( CvCNNLayer** p_layer ); 806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionForward( CvCNNLayer* layer, const CvMat* X, CvMat* Y ); 826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionBackward( CvCNNLayer* layer, int t, 846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX ); 856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*------------------------ functions for sub-sampling layer ----------------------------*/ 876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingRelease( CvCNNLayer** p_layer ); 886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingForward( CvCNNLayer* layer, const CvMat* X, CvMat* Y ); 906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingBackward( CvCNNLayer* layer, int t, 926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX ); 936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*------------------------ functions for full connected layer --------------------------*/ 956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectRelease( CvCNNLayer** p_layer ); 966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectForward( CvCNNLayer* layer, const CvMat* X, CvMat* Y ); 986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectBackward( CvCNNLayer* layer, int, 1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat*, const CvMat* dE_dY, CvMat* dE_dX ); 1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\ 1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn* Functions implementations * 1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/ 1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#define ICV_CHECK_CNN_NETWORK(network) \ 1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ \ 1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* first_layer, *layer, *last_layer; \ 1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_layers, i; \ 1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !network ) \ 1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, \ 1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "Null <network> pointer. Network must be created by user." ); \ 1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n_layers = network->n_layers; \ 1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn first_layer = last_layer = network->layers; \ 1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0, layer = first_layer; i < n_layers && layer; i++ ) \ 1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { \ 1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_LAYER(layer) ) \ 1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, "Invalid network" ); \ 1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn last_layer = layer; \ 1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer = layer->next_layer; \ 1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } \ 1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn \ 1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( i == 0 || i != n_layers || first_layer->prev_layer || layer ) \ 1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, "Invalid network" ); \ 1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn \ 1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( first_layer->n_input_planes != 1 ) \ 1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "First layer must contain only one input plane" ); \ 1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn \ 1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( img_size != first_layer->input_height*first_layer->input_width ) \ 1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid input sizes of the first layer" ); \ 1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn \ 1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params->etalons->cols != last_layer->n_output_planes* \ 1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn last_layer->output_height*last_layer->output_width ) \ 1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid output sizes of the last layer" ); \ 1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#define ICV_CHECK_CNN_MODEL_PARAMS(params) \ 1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ \ 1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !params ) \ 1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, "Null <params> pointer" ); \ 1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn \ 1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_MAT_OF_TYPE(params->etalons, CV_32FC1) ) \ 1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "<etalons> must be CV_32FC1 type" ); \ 1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params->etalons->rows != cnn_model->cls_labels->cols ) \ 1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid <etalons> size" ); \ 1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn \ 1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params->grad_estim_type != CV_CNN_GRAD_ESTIM_RANDOM && \ 1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params->grad_estim_type != CV_CNN_GRAD_ESTIM_BY_WORST_IMG ) \ 1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid <grad_estim_type>" ); \ 1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn \ 1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params->start_iter < 0 ) \ 1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Parameter <start_iter> must be positive or zero" ); \ 1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn \ 1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params->max_iter < 1 ) \ 1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params->max_iter = 1; \ 1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\ 1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn* Classifier functions * 1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/ 1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvStatModel* 1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RenncvTrainCNNClassifier( const CvMat* _train_data, int tflag, 1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _responses, 1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvStatModelParams* _params, 1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat*, const CvMat* _sample_idx, const CvMat*, const CvMat* ) 1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNStatModel* cnn_model = 0; 1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float** out_train_data = 0; 1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* responses = 0; 1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("cvTrainCNNClassifier"); 1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_images; 1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int img_size; 1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNStatModelParams* params = (CvCNNStatModelParams*)_params; 1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cnn_model = (CvCNNStatModel*)cvCreateStatModel( 1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_STAT_MODEL_MAGIC_VAL|CV_CNN_MAGIC_VAL, sizeof(CvCNNStatModel), 1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvCNNModelRelease, icvCNNModelPredict, icvCNNModelUpdate )); 1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvPrepareTrainData( "cvTrainCNNClassifier", 1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _train_data, tflag, _responses, CV_VAR_CATEGORICAL, 1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 0, _sample_idx, false, &out_train_data, 1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn &n_images, &img_size, &img_size, &responses, 1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn &cnn_model->cls_labels, 0 )); 1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ICV_CHECK_CNN_MODEL_PARAMS(params); 1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ICV_CHECK_CNN_NETWORK(params->network); 1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cnn_model->network = params->network; 1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cnn_model->etalons = (CvMat*)cvClone( params->etalons )); 1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( icvTrainCNNetwork( cnn_model->network, out_train_data, responses, 1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cnn_model->etalons, params->grad_estim_type, params->max_iter, 1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params->start_iter )); 1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cvGetErrStatus() < 0 && cnn_model ) 2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cnn_model->release( (CvStatModel**)&cnn_model ); 2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &out_train_data ); 2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &responses ); 2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return (CvStatModel*)cnn_model; 2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvTrainCNNetwork( CvCNNetwork* network, 2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float** images, 2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* responses, 2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* etalons, 2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int grad_estim_type, 2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int max_iter, 2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int start_iter ) 2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat** X = 0; 2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat** dE_dX = 0; 2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int n_layers = network->n_layers; 2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int k; 2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvTrainCNNetwork"); 2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* first_layer = network->layers; 2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int img_height = first_layer->input_height; 2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int img_width = first_layer->input_width; 2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int img_size = img_width*img_height; 2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int n_images = responses->cols; 2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat image = cvMat( 1, img_size, CV_32FC1 ); 2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* layer; 2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n; 2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvRNG rng = cvRNG(-1); 2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(X = (CvMat**)cvAlloc( (n_layers+1)*sizeof(CvMat*) )); 2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(dE_dX = (CvMat**)cvAlloc( (n_layers+1)*sizeof(CvMat*) )); 2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn memset( X, 0, (n_layers+1)*sizeof(CvMat*) ); 2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn memset( dE_dX, 0, (n_layers+1)*sizeof(CvMat*) ); 2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(X[0] = cvCreateMat( img_height*img_width,1,CV_32FC1 )); 2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(dE_dX[0] = cvCreateMat( 1, X[0]->rows, CV_32FC1 )); 2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0, layer = first_layer; k < n_layers; k++, layer = layer->next_layer ) 2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(X[k+1] = cvCreateMat( layer->n_output_planes*layer->output_height* 2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->output_width, 1, CV_32FC1 )); 2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(dE_dX[k+1] = cvCreateMat( 1, X[k+1]->rows, CV_32FC1 )); 2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( n = 1; n <= max_iter; n++ ) 2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float loss, max_loss = 0; 2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i; 2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int worst_img_idx = -1; 2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* right_etal_idx = responses->data.i; 2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat etalon; 2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // Find the worst image (which produces the greatest loss) or use the random image 2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( grad_estim_type == CV_CNN_GRAD_ESTIM_BY_WORST_IMG ) 2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n_images; i++, right_etal_idx++ ) 2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn image.data.fl = (float*)images[i]; 2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvTranspose( &image, X[0] ); 2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0, layer = first_layer; k < n_layers; k++, layer = layer->next_layer ) 2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->forward( layer, X[k], X[k+1] )); 2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvTranspose( X[n_layers], dE_dX[n_layers] ); 2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetRow( etalons, &etalon, *right_etal_idx ); 2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn loss = (float)cvNorm( dE_dX[n_layers], &etalon ); 2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( loss > max_loss ) 2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_loss = loss; 2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn worst_img_idx = i; 2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn worst_img_idx = cvRandInt(&rng) % n_images; 2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // Train network on the worst image 2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 1) Compute the network output on the <image> 2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn image.data.fl = (float*)images[worst_img_idx]; 2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvTranspose( &image, X[0] )); 2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0, layer = first_layer; k < n_layers - 1; k++, layer = layer->next_layer ) 2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->forward( layer, X[k], X[k+1] )); 2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->forward( layer, X[k], X[k+1] )); 2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 2) Compute the gradient 2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvTranspose( X[n_layers], dE_dX[n_layers] ); 2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetRow( etalons, &etalon, responses->data.i[worst_img_idx] ); 2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSub( dE_dX[n_layers], &etalon, dE_dX[n_layers] ); 2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 3) Update weights by the gradient descent 2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = n_layers; k > 0; k--, layer = layer->prev_layer ) 2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->backward( layer, n + start_iter, X[k-1], dE_dX[k], dE_dX[k-1] )); 3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k <= n_layers; k++ ) 3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &X[k] ); 3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dE_dX[k] ); 3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &X ); 3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &dE_dX ); 3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic float icvCNNModelPredict( const CvStatModel* model, 3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _image, 3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* probs ) 3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat** X = 0; 3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* img_data = 0; 3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_layers = 0; 3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int best_etal_idx = -1; 3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int k; 3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNModelPredict"); 3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNStatModel* cnn_model = (CvCNNStatModel*)model; 3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* first_layer, *layer = 0; 3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int img_height, img_width, img_size; 3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int nclasses, i; 3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float loss, min_loss = FLT_MAX; 3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* probs_data; 3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat etalon, image; 3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_CNN(model) ) 3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid model" ); 3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nclasses = cnn_model->cls_labels->cols; 3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n_layers = cnn_model->network->n_layers; 3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn first_layer = cnn_model->network->layers; 3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn img_height = first_layer->input_height; 3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn img_width = first_layer->input_width; 3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn img_size = img_height*img_width; 3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvPreparePredictData( _image, img_size, 0, nclasses, probs, &img_data ); 3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(X = (CvMat**)cvAlloc( (n_layers+1)*sizeof(CvMat*) )); 3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn memset( X, 0, (n_layers+1)*sizeof(CvMat*) ); 3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(X[0] = cvCreateMat( img_size,1,CV_32FC1 )); 3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0, layer = first_layer; k < n_layers; k++, layer = layer->next_layer ) 3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(X[k+1] = cvCreateMat( layer->n_output_planes*layer->output_height* 3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->output_width, 1, CV_32FC1 )); 3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn image = cvMat( 1, img_size, CV_32FC1, img_data ); 3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvTranspose( &image, X[0] ); 3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0, layer = first_layer; k < n_layers; k++, layer = layer->next_layer ) 3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->forward( layer, X[k], X[k+1] )); 3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn probs_data = probs ? probs->data.fl : 0; 3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn etalon = cvMat( cnn_model->etalons->cols, 1, CV_32FC1, cnn_model->etalons->data.fl ); 3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < nclasses; i++, etalon.data.fl += cnn_model->etalons->cols ) 3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn loss = (float)cvNorm( X[n_layers], &etalon ); 3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( loss < min_loss ) 3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_loss = loss; 3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_etal_idx = i; 3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( probs ) 3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *probs_data++ = -loss; 3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( probs ) 3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvExp( probs, probs ); 3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvScalar sum = cvSum( probs ); 3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvConvertScale( probs, probs, 1./sum.val[0] ); 3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k <= n_layers; k++ ) 3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &X[k] ); 3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &X ); 3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( img_data != _image->data.fl ) 3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &img_data ); 3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return ((float) ((CvCNNStatModel*)model)->cls_labels->data.i[best_etal_idx]); 3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNModelUpdate( 3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvStatModel* _cnn_model, const CvMat* _train_data, int tflag, 3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _responses, const CvStatModelParams* _params, 3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat*, const CvMat* _sample_idx, 3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat*, const CvMat* ) 4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float** out_train_data = 0; 4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* responses = 0; 4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* cls_labels = 0; 4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNModelUpdate"); 4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_images, img_size, i; 4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNStatModelParams* params = (CvCNNStatModelParams*)_params; 4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNStatModel* cnn_model = (CvCNNStatModel*)_cnn_model; 4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_CNN(cnn_model) ) 4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid model" ); 4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvPrepareTrainData( "cvTrainCNNClassifier", 4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _train_data, tflag, _responses, CV_VAR_CATEGORICAL, 4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 0, _sample_idx, false, &out_train_data, 4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn &n_images, &img_size, &img_size, &responses, 4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn &cls_labels, 0, 0 )); 4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ICV_CHECK_CNN_MODEL_PARAMS(params); 4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // Number of classes must be the same as when classifiers was created 4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_ARE_SIZES_EQ(cls_labels, cnn_model->cls_labels) ) 4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Number of classes must be left unchanged" ); 4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < cls_labels->cols; i++ ) 4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cls_labels->data.i[i] != cnn_model->cls_labels->data.i[i] ) 4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Number of classes must be left unchanged" ); 4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( icvTrainCNNetwork( cnn_model->network, out_train_data, responses, 4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cnn_model->etalons, params->grad_estim_type, params->max_iter, 4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params->start_iter )); 4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &out_train_data ); 4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &responses ); 4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNModelRelease( CvStatModel** cnn_model ) 4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNModelRelease"); 4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNStatModel* cnn; 4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !cnn_model ) 4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, "Null double pointer" ); 4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cnn = *(CvCNNStatModel**)cnn_model; 4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &cnn->cls_labels ); 4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &cnn->etalons ); 4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cnn->network->release( &cnn->network ); 4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &cnn ); 4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\ 4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn* Network functions * 4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/ 4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvCNNetwork* cvCreateCNNetwork( CvCNNLayer* first_layer ) 4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNetwork* network = 0; 4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "cvCreateCNNetwork" ); 4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_LAYER(first_layer) ) 4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(network = (CvCNNetwork*)cvAlloc( sizeof(CvCNNetwork) )); 4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn memset( network, 0, sizeof(CvCNNetwork) ); 4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn network->layers = first_layer; 4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn network->n_layers = 1; 4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn network->release = icvCNNetworkRelease; 4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn network->add_layer = icvCNNetworkAddLayer; 4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cvGetErrStatus() < 0 && network ) 4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &network ); 4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return network; 4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNetworkAddLayer( CvCNNetwork* network, CvCNNLayer* layer ) 4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "icvCNNetworkAddLayer" ); 4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* prev_layer; 5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( network == NULL ) 5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, "Null <network> pointer" ); 5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_layer = network->layers; 5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn while( prev_layer->next_layer ) 5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_layer = prev_layer->next_layer; 5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ICV_IS_CNN_FULLCONNECT_LAYER(layer) ) 5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( layer->n_input_planes != prev_layer->output_width*prev_layer->output_height* 5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_layer->n_output_planes ) 5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Unmatched size of the new layer" ); 5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( layer->input_height != 1 || layer->output_height != 1 || 5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->input_width != 1 || layer->output_width != 1 ) 5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid size of the new layer" ); 5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( ICV_IS_CNN_CONVOLUTION_LAYER(layer) || ICV_IS_CNN_SUBSAMPLING_LAYER(layer) ) 5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( prev_layer->n_output_planes != layer->n_input_planes || 5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_layer->output_height != layer->input_height || 5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_layer->output_width != layer->input_width ) 5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Unmatched size of the new layer" ); 5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->prev_layer = prev_layer; 5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_layer->next_layer = layer; 5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn network->n_layers++; 5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNetworkRelease( CvCNNetwork** network_pptr ) 5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "icvReleaseCNNetwork" ); 5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNetwork* network = 0; 5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* layer = 0, *next_layer = 0; 5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int k; 5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( network_pptr == NULL ) 5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Null double pointer" ); 5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( *network_pptr == NULL ) 5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn network = *network_pptr; 5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer = network->layers; 5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( layer == NULL ) 5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "CNN is empty (does not contain any layer)" ); 5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // k is the number of the layer to be deleted 5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < network->n_layers && layer; k++ ) 5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn next_layer = layer->next_layer; 5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->release( &layer ); 5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer = next_layer; 5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( k != network->n_layers || layer) 5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid network" ); 5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &network ); 5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\ 5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn* Layer functions * 5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/ 5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CvCNNLayer* icvCreateCNNLayer( int layer_type, int header_size, 5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_input_planes, int input_height, int input_width, 5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_output_planes, int output_height, int output_width, 5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float init_learn_rate, int learn_rate_decrease_type, 5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayerRelease release, CvCNNLayerForward forward, CvCNNLayerBackward backward ) 5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* layer = 0; 5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCreateCNNLayer"); 5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT( release && forward && backward ) 5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT( header_size >= sizeof(CvCNNLayer) ) 5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( n_input_planes < 1 || n_output_planes < 1 || 5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn input_height < 1 || input_width < 1 || 5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn output_height < 1 || output_width < 1 || 5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn input_height < output_height || 5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn input_width < output_width ) 5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Incorrect input or output parameters" ); 5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( init_learn_rate < FLT_EPSILON ) 5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Initial learning rate must be positive" ); 5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( learn_rate_decrease_type != CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY && 5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn learn_rate_decrease_type != CV_CNN_LEARN_RATE_DECREASE_SQRT_INV && 5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn learn_rate_decrease_type != CV_CNN_LEARN_RATE_DECREASE_LOG_INV ) 5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid type of learning rate dynamics" ); 6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer = (CvCNNLayer*)cvAlloc( header_size )); 6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn memset( layer, 0, header_size ); 6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->flags = ICV_CNN_LAYER|layer_type; 6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT( ICV_IS_CNN_LAYER(layer) ) 6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->n_input_planes = n_input_planes; 6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->input_height = input_height; 6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->input_width = input_width; 6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->n_output_planes = n_output_planes; 6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->output_height = output_height; 6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->output_width = output_width; 6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->init_learn_rate = init_learn_rate; 6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->learn_rate_decrease_type = learn_rate_decrease_type; 6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->release = release; 6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->forward = forward; 6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->backward = backward; 6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cvGetErrStatus() < 0 && layer) 6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &layer ); 6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return layer; 6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvCNNLayer* cvCreateCNNConvolutionLayer( 6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_input_planes, int input_height, int input_width, 6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_output_planes, int K, 6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float init_learn_rate, int learn_rate_decrease_type, 6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* connect_mask, CvMat* weights ) 6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNConvolutionLayer* layer = 0; 6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("cvCreateCNNConvolutionLayer"); 6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int output_height = input_height - K + 1; 6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int output_width = input_width - K + 1; 6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( K < 1 || init_learn_rate <= 0 ) 6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Incorrect parameters" ); 6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer = (CvCNNConvolutionLayer*)icvCreateCNNLayer( ICV_CNN_CONVOLUTION_LAYER, 6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sizeof(CvCNNConvolutionLayer), n_input_planes, input_height, input_width, 6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n_output_planes, output_height, output_width, 6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn init_learn_rate, learn_rate_decrease_type, 6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvCNNConvolutionRelease, icvCNNConvolutionForward, icvCNNConvolutionBackward )); 6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->K = K; 6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->weights = cvCreateMat( n_output_planes, K*K+1, CV_32FC1 )); 6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->connect_mask = cvCreateMat( n_output_planes, n_input_planes, CV_8UC1)); 6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weights ) 6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_MAT_OF_TYPE( weights, CV_32FC1 ) ) 6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "Type of initial weights matrix must be CV_32FC1" ); 6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_ARE_SIZES_EQ( weights, layer->weights ) ) 6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "Invalid size of initial weights matrix" ); 6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvCopy( weights, layer->weights )); 6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvRNG rng = cvRNG( 0xFFFFFFFF ); 6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvRandArr( &rng, layer->weights, CV_RAND_UNI, cvRealScalar(-1), cvRealScalar(1) ); 6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( connect_mask ) 6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_MAT_OF_TYPE( connect_mask, CV_8UC1 ) ) 6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "Type of connection matrix must be CV_32FC1" ); 6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_ARE_SIZES_EQ( connect_mask, layer->connect_mask ) ) 6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "Invalid size of connection matrix" ); 6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvCopy( connect_mask, layer->connect_mask )); 6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvSet( layer->connect_mask, cvRealScalar(1) )); 6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cvGetErrStatus() < 0 && layer ) 6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->weights ); 6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->connect_mask ); 6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &layer ); 6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return (CvCNNLayer*)layer; 6946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 6956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 6976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvCNNLayer* cvCreateCNNSubSamplingLayer( 6986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_input_planes, int input_height, int input_width, 6996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sub_samp_scale, float a, float s, 7006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float init_learn_rate, int learn_rate_decrease_type, CvMat* weights ) 7016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 7036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNSubSamplingLayer* layer = 0; 7046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("cvCreateCNNSubSamplingLayer"); 7066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 7076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int output_height = input_height/sub_samp_scale; 7096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int output_width = input_width/sub_samp_scale; 7106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int n_output_planes = n_input_planes; 7116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sub_samp_scale < 1 || a <= 0 || s <= 0) 7136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Incorrect parameters" ); 7146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer = (CvCNNSubSamplingLayer*)icvCreateCNNLayer( ICV_CNN_SUBSAMPLING_LAYER, 7166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sizeof(CvCNNSubSamplingLayer), n_input_planes, input_height, input_width, 7176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n_output_planes, output_height, output_width, 7186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn init_learn_rate, learn_rate_decrease_type, 7196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvCNNSubSamplingRelease, icvCNNSubSamplingForward, icvCNNSubSamplingBackward )); 7206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->sub_samp_scale = sub_samp_scale; 7226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->a = a; 7236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->s = s; 7246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->sumX = 7266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvCreateMat( n_output_planes*output_width*output_height, 1, CV_32FC1 )); 7276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->exp2ssumWX = 7286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvCreateMat( n_output_planes*output_width*output_height, 1, CV_32FC1 )); 7296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( layer->sumX ); 7316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( layer->exp2ssumWX ); 7326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->weights = cvCreateMat( n_output_planes, 2, CV_32FC1 )); 7346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weights ) 7356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_MAT_OF_TYPE( weights, CV_32FC1 ) ) 7376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "Type of initial weights matrix must be CV_32FC1" ); 7386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_ARE_SIZES_EQ( weights, layer->weights ) ) 7396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "Invalid size of initial weights matrix" ); 7406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvCopy( weights, layer->weights )); 7416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 7436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvRNG rng = cvRNG( 0xFFFFFFFF ); 7456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvRandArr( &rng, layer->weights, CV_RAND_UNI, cvRealScalar(-1), cvRealScalar(1) ); 7466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 7496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cvGetErrStatus() < 0 && layer ) 7516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->exp2ssumWX ); 7536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &layer ); 7546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return (CvCNNLayer*)layer; 7576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 7586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 7606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvCNNLayer* cvCreateCNNFullConnectLayer( 7616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_inputs, int n_outputs, float a, float s, 7626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float init_learn_rate, int learn_rate_decrease_type, CvMat* weights ) 7636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 7646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNFullConnectLayer* layer = 0; 7656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("cvCreateCNNFullConnectLayer"); 7676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 7686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( a <= 0 || s <= 0 || init_learn_rate <= 0) 7706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Incorrect parameters" ); 7716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer = (CvCNNFullConnectLayer*)icvCreateCNNLayer( ICV_CNN_FULLCONNECT_LAYER, 7736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sizeof(CvCNNFullConnectLayer), n_inputs, 1, 1, n_outputs, 1, 1, 7746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn init_learn_rate, learn_rate_decrease_type, 7756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvCNNFullConnectRelease, icvCNNFullConnectForward, icvCNNFullConnectBackward )); 7766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->a = a; 7786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->s = s; 7796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->exp2ssumWX = cvCreateMat( n_outputs, 1, CV_32FC1 )); 7816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( layer->exp2ssumWX ); 7826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer->weights = cvCreateMat( n_outputs, n_inputs+1, CV_32FC1 )); 7846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weights ) 7856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_MAT_OF_TYPE( weights, CV_32FC1 ) ) 7876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "Type of initial weights matrix must be CV_32FC1" ); 7886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_ARE_SIZES_EQ( weights, layer->weights ) ) 7896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, "Invalid size of initial weights matrix" ); 7906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvCopy( weights, layer->weights )); 7916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 7936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvRNG rng = cvRNG( 0xFFFFFFFF ); 7956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvRandArr( &rng, layer->weights, CV_RAND_UNI, cvRealScalar(-1), cvRealScalar(1) ); 7966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 7996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cvGetErrStatus() < 0 && layer ) 8016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->exp2ssumWX ); 8036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->weights ); 8046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &layer ); 8056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return (CvCNNLayer*)layer; 8086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\ 8126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn* Layer FORWARD functions * 8136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/ 8146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionForward( CvCNNLayer* _layer, 8156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* X, 8166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* Y ) 8176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNConvolutionForward"); 8196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_CONVOLUTION_LAYER(_layer) ) 8216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 8226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn {__BEGIN__; 8246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvCNNConvolutionLayer* layer = (CvCNNConvolutionLayer*) _layer; 8266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int K = layer->K; 8286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int n_weights_for_Yplane = K*K + 1; 8296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int nXplanes = layer->n_input_planes; 8316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Xheight = layer->input_height; 8326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Xwidth = layer->input_width ; 8336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Xsize = Xwidth*Xheight; 8346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int nYplanes = layer->n_output_planes; 8366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Yheight = layer->output_height; 8376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Ywidth = layer->output_width; 8386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Ysize = Ywidth*Yheight; 8396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int xx, yy, ni, no, kx, ky; 8416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float *Yplane = 0, *Xplane = 0, *w = 0; 8426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn uchar* connect_mask_data = 0; 8436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT( X->rows == nXplanes*Xsize && X->cols == 1 ); 8456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT( Y->rows == nYplanes*Ysize && Y->cols == 1 ); 8466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetZero( Y ); 8486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn Yplane = Y->data.fl; 8506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn connect_mask_data = layer->connect_mask->data.ptr; 8516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = layer->weights->data.fl; 8526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( no = 0; no < nYplanes; no++, Yplane += Ysize, w += n_weights_for_Yplane ) 8536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn Xplane = X->data.fl; 8556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ni = 0; ni < nXplanes; ni++, Xplane += Xsize, connect_mask_data++ ) 8566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( *connect_mask_data ) 8586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* Yelem = Yplane; 8606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // Xheight-K+1 == Yheight && Xwidth-K+1 == Ywidth 8626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( yy = 0; yy < Xheight-K+1; yy++ ) 8636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( xx = 0; xx < Xwidth-K+1; xx++, Yelem++ ) 8656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* templ = Xplane+yy*Xwidth+xx; 8676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float WX = 0; 8686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ky = 0; ky < K; ky++, templ += Xwidth-K ) 8696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( kx = 0; kx < K; kx++, templ++ ) 8716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn WX += *templ*w[ky*K+kx]; 8736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *Yelem += WX + w[K*K]; 8766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn }__END__; 8826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 8856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingForward( CvCNNLayer* _layer, 8866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* X, 8876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* Y ) 8886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNSubSamplingForward"); 8906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_SUBSAMPLING_LAYER(_layer) ) 8926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 8936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn {__BEGIN__; 8956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvCNNSubSamplingLayer* layer = (CvCNNSubSamplingLayer*) _layer; 8976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int sub_sampl_scale = layer->sub_samp_scale; 8996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int nplanes = layer->n_input_planes; 9006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Xheight = layer->input_height; 9026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Xwidth = layer->input_width ; 9036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Xsize = Xwidth*Xheight; 9046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Yheight = layer->output_height; 9066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Ywidth = layer->output_width; 9076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Ysize = Ywidth*Yheight; 9086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int xx, yy, ni, kx, ky; 9106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* sumX_data = 0, *w = 0; 9116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat sumX_sub_col, exp2ssumWX_sub_col; 9126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT(X->rows == nplanes*Xsize && X->cols == 1); 9146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT(layer->exp2ssumWX->cols == 1 && layer->exp2ssumWX->rows == nplanes*Ysize); 9156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // update inner variable layer->exp2ssumWX, which will be used in back-progation 9176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( layer->sumX ); 9186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( layer->exp2ssumWX ); 9196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ky = 0; ky < sub_sampl_scale; ky++ ) 9216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( kx = 0; kx < sub_sampl_scale; kx++ ) 9226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* Xplane = X->data.fl; 9246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sumX_data = layer->sumX->data.fl; 9256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ni = 0; ni < nplanes; ni++, Xplane += Xsize ) 9266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( yy = 0; yy < Yheight; yy++ ) 9286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( xx = 0; xx < Ywidth; xx++, sumX_data++ ) 9296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *sumX_data += Xplane[((yy+ky)*Xwidth+(xx+kx))]; 9306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = layer->weights->data.fl; 9346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetRows( layer->sumX, &sumX_sub_col, 0, Ysize ); 9356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetRows( layer->exp2ssumWX, &exp2ssumWX_sub_col, 0, Ysize ); 9366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ni = 0; ni < nplanes; ni++, w += 2 ) 9376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvConvertScale( &sumX_sub_col, &exp2ssumWX_sub_col, w[0], w[1] )); 9396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sumX_sub_col.data.fl += Ysize; 9406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn exp2ssumWX_sub_col.data.fl += Ysize; 9416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvScale( layer->exp2ssumWX, layer->exp2ssumWX, 2.0*layer->s )); 9446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvExp( layer->exp2ssumWX, layer->exp2ssumWX )); 9456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvMinS( layer->exp2ssumWX, FLT_MAX, layer->exp2ssumWX )); 9466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//#ifdef _DEBUG 9476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* exp2ssumWX_data = layer->exp2ssumWX->data.fl; 9496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ni = 0; ni < layer->exp2ssumWX->rows; ni++, exp2ssumWX_data++ ) 9506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( *exp2ssumWX_data == FLT_MAX ) 9526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetErrStatus( 1 ); 9536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//#endif 9566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute the output variable Y == ( a - 2a/(layer->exp2ssumWX + 1)) 9576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvAddS( layer->exp2ssumWX, cvRealScalar(1), Y )); 9586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvDiv( 0, Y, Y, -2.0*layer->a )); 9596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvAddS( Y, cvRealScalar(layer->a), Y )); 9606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn }__END__; 9626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 9656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectForward( CvCNNLayer* _layer, const CvMat* X, CvMat* Y ) 9666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNFullConnectForward"); 9686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_FULLCONNECT_LAYER(_layer) ) 9706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 9716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn {__BEGIN__; 9736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvCNNFullConnectLayer* layer = (CvCNNFullConnectLayer*)_layer; 9756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* weights = layer->weights; 9766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat sub_weights, bias; 9776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT(X->cols == 1 && X->rows == layer->n_input_planes); 9796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT(Y->cols == 1 && Y->rows == layer->n_output_planes); 9806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvGetSubRect( weights, &sub_weights, 9826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvRect(0, 0, weights->cols-1, weights->rows ))); 9836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvGetCol( weights, &bias, weights->cols-1)); 9846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // update inner variable layer->exp2ssumWX, which will be used in Back-Propagation 9866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvGEMM( &sub_weights, X, 2*layer->s, &bias, 2*layer->s, layer->exp2ssumWX )); 9876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvExp( layer->exp2ssumWX, layer->exp2ssumWX )); 9886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvMinS( layer->exp2ssumWX, FLT_MAX, layer->exp2ssumWX )); 9896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//#ifdef _DEBUG 9906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* exp2ssumWX_data = layer->exp2ssumWX->data.fl; 9926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i; 9936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < layer->exp2ssumWX->rows; i++, exp2ssumWX_data++ ) 9946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( *exp2ssumWX_data == FLT_MAX ) 9966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetErrStatus( 1 ); 9976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//#endif 10006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute the output variable Y == ( a - 2a/(layer->exp2ssumWX + 1)) 10016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvAddS( layer->exp2ssumWX, cvRealScalar(1), Y )); 10026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvDiv( 0, Y, Y, -2.0*layer->a )); 10036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvAddS( Y, cvRealScalar(layer->a), Y )); 10046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn }__END__; 10066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 10076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\ 10096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn* Layer BACKWARD functions * 10106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/ 10116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* <dE_dY>, <dE_dX> should be row-vectors. 10136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn Function computes partial derivatives <dE_dX> 10146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn of the loss function with respect to the planes components 10156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn of the previous layer (X). 10166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn It is a basic function for back propagation method. 10176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn Input parameter <dE_dY> is the partial derivative of the 10186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn loss function with respect to the planes components 10196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn of the current layer. */ 10206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionBackward( 10216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* _layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX ) 10226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 10236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dY_dX = 0; 10246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dY_dW = 0; 10256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dE_dW = 0; 10266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNConvolutionBackward"); 10286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_CONVOLUTION_LAYER(_layer) ) 10306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 10316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn {__BEGIN__; 10336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvCNNConvolutionLayer* layer = (CvCNNConvolutionLayer*) _layer; 10356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int K = layer->K; 10376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int n_X_planes = layer->n_input_planes; 10396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int X_plane_height = layer->input_height; 10406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int X_plane_width = layer->input_width; 10416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int X_plane_size = X_plane_height*X_plane_width; 10426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int n_Y_planes = layer->n_output_planes; 10446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Y_plane_height = layer->output_height; 10456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Y_plane_width = layer->output_width; 10466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Y_plane_size = Y_plane_height*Y_plane_width; 10476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int no, ni, yy, xx, ky, kx; 10496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int X_idx = 0, Y_idx = 0; 10506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float *X_plane = 0, *w = 0; 10526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* weights = layer->weights; 10546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT( t >= 1 ); 10566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT( n_Y_planes == weights->rows ); 10576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dY_dX = cvCreateMat( n_Y_planes*Y_plane_size, X->rows, CV_32FC1 ); 10596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dY_dW = cvCreateMat( dY_dX->rows, weights->cols*weights->rows, CV_32FC1 ); 10606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dE_dW = cvCreateMat( 1, dY_dW->cols, CV_32FC1 ); 10616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( dY_dX ); 10636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( dY_dW ); 10646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute gradient of the loss function with respect to X and W 10666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( no = 0; no < n_Y_planes; no++, Y_idx += Y_plane_size ) 10676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = weights->data.fl + no*(K*K+1); 10696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn X_idx = 0; 10706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn X_plane = X->data.fl; 10716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ni = 0; ni < n_X_planes; ni++, X_plane += X_plane_size ) 10726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( layer->connect_mask->data.ptr[ni*n_Y_planes+no] ) 10746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( yy = 0; yy < X_plane_height - K + 1; yy++ ) 10766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( xx = 0; xx < X_plane_width - K + 1; xx++ ) 10786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ky = 0; ky < K; ky++ ) 10806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( kx = 0; kx < K; kx++ ) 10826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_ELEM(*dY_dX, float, Y_idx+yy*Y_plane_width+xx, 10846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn X_idx+(yy+ky)*X_plane_width+(xx+kx)) = w[ky*K+kx]; 10856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // dY_dWi, i=1,...,K*K 10876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_ELEM(*dY_dW, float, Y_idx+yy*Y_plane_width+xx, 10886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn no*(K*K+1)+ky*K+kx) += 10896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn X_plane[(yy+ky)*X_plane_width+(xx+kx)]; 10906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // dY_dW(K*K+1)==1 because W(K*K+1) is bias 10936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_ELEM(*dY_dW, float, Y_idx+yy*Y_plane_width+xx, 10946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn no*(K*K+1)+K*K) += 1; 10956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn X_idx += X_plane_size; 10996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvMatMul( dE_dY, dY_dW, dE_dW )); 11036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvMatMul( dE_dY, dY_dX, dE_dX )); 11046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // update weights 11066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat dE_dW_mat; 11086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float eta; 11096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_LOG_INV ) 11106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn eta = -layer->init_learn_rate/logf(1+(float)t); 11116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_SQRT_INV ) 11126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn eta = -layer->init_learn_rate/sqrtf((float)t); 11136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 11146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn eta = -layer->init_learn_rate/(float)t; 11156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReshape( dE_dW, &dE_dW_mat, 0, weights->rows ); 11166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvScaleAdd( &dE_dW_mat, cvRealScalar(eta), weights, weights ); 11176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn }__END__; 11206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dY_dX ); 11226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dY_dW ); 11236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dE_dW ); 11246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 11256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 11276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingBackward( 11286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* _layer, int t, const CvMat*, const CvMat* dE_dY, CvMat* dE_dX ) 11296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 11306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // derivative of activation function 11316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dY_dX_elems = 0; // elements of matrix dY_dX 11326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dY_dW_elems = 0; // elements of matrix dY_dW 11336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dE_dW = 0; 11346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNSubSamplingBackward"); 11366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_SUBSAMPLING_LAYER(_layer) ) 11386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 11396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn {__BEGIN__; 11416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvCNNSubSamplingLayer* layer = (CvCNNSubSamplingLayer*) _layer; 11436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Xwidth = layer->input_width; 11456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Ywidth = layer->output_width; 11466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Yheight = layer->output_height; 11476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int Ysize = Ywidth * Yheight; 11486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int scale = layer->sub_samp_scale; 11496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int k_max = layer->n_output_planes * Yheight; 11506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int k, i, j, m; 11526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* dY_dX_current_elem = 0, *dE_dX_start = 0, *dE_dW_data = 0, *w = 0; 11536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat dy_dw0, dy_dw1; 11546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat activ_func_der, sumX_row; 11556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat dE_dY_sub_row, dY_dX_sub_col, dy_dw0_sub_row, dy_dw1_sub_row; 11566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(dY_dX_elems = cvCreateMat( layer->sumX->rows, 1, CV_32FC1 )); 11586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(dY_dW_elems = cvCreateMat( 2, layer->sumX->rows, CV_32FC1 )); 11596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(dE_dW = cvCreateMat( 1, 2*layer->n_output_planes, CV_32FC1 )); 11606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute derivative of activ.func. 11626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // ==<dY_dX_elems> = 4as*(layer->exp2ssumWX)/(layer->exp2ssumWX + 1)^2 11636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvAddS( layer->exp2ssumWX, cvRealScalar(1), dY_dX_elems )); 11646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvPow( dY_dX_elems, dY_dX_elems, -2.0 )); 11656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvMul( dY_dX_elems, layer->exp2ssumWX, dY_dX_elems, 4.0*layer->a*layer->s )); 11666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute <dE_dW> 11686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // a) compute <dY_dW_elems> 11696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReshape( dY_dX_elems, &activ_func_der, 0, 1 ); 11706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetRow( dY_dW_elems, &dy_dw0, 0 ); 11716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetRow( dY_dW_elems, &dy_dw1, 1 ); 11726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvCopy( &activ_func_der, &dy_dw0 )); 11736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvCopy( &activ_func_der, &dy_dw1 )); 11746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReshape( layer->sumX, &sumX_row, 0, 1 ); 11766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvMul( &dy_dw0, &sumX_row, &dy_dw0 ); 11776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // b) compute <dE_dW> = <dE_dY>*<dY_dW_elems> 11796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetCols( dE_dY, &dE_dY_sub_row, 0, Ysize ); 11806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetCols( &dy_dw0, &dy_dw0_sub_row, 0, Ysize ); 11816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetCols( &dy_dw1, &dy_dw1_sub_row, 0, Ysize ); 11826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dE_dW_data = dE_dW->data.fl; 11836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < layer->n_output_planes; i++ ) 11846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *dE_dW_data++ = (float)cvDotProduct( &dE_dY_sub_row, &dy_dw0_sub_row ); 11866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *dE_dW_data++ = (float)cvDotProduct( &dE_dY_sub_row, &dy_dw1_sub_row ); 11876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dE_dY_sub_row.data.fl += Ysize; 11896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dy_dw0_sub_row.data.fl += Ysize; 11906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dy_dw1_sub_row.data.fl += Ysize; 11916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute <dY_dX> = layer->weights*<dY_dX> 11946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = layer->weights->data.fl; 11956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvGetRows( dY_dX_elems, &dY_dX_sub_col, 0, Ysize ); 11966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < layer->n_input_planes; i++, w++, dY_dX_sub_col.data.fl += Ysize ) 11976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvConvertScale( &dY_dX_sub_col, &dY_dX_sub_col, (float)*w )); 11986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute <dE_dX> 12006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvReshape( dY_dX_elems, dY_dX_elems, 0, 1 )); 12016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvMul( dY_dX_elems, dE_dY, dY_dX_elems )); 12026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dY_dX_current_elem = dY_dX_elems->data.fl; 12046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dE_dX_start = dE_dX->data.fl; 12056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < k_max; k++ ) 12066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < Ywidth; i++, dY_dX_current_elem++ ) 12086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* dE_dX_current_elem = dE_dX_start; 12106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < scale; j++, dE_dX_current_elem += Xwidth - scale ) 12116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( m = 0; m < scale; m++, dE_dX_current_elem++ ) 12136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *dE_dX_current_elem = *dY_dX_current_elem; 12146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dE_dX_start += scale; 12166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dE_dX_start += Xwidth * (scale - 1); 12186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // update weights 12216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat dE_dW_mat, *weights = layer->weights; 12236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float eta; 12246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_LOG_INV ) 12256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn eta = -layer->init_learn_rate/logf(1+(float)t); 12266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_SQRT_INV ) 12276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn eta = -layer->init_learn_rate/sqrtf((float)t); 12286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 12296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn eta = -layer->init_learn_rate/(float)t; 12306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReshape( dE_dW, &dE_dW_mat, 0, weights->rows ); 12316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvScaleAdd( &dE_dW_mat, cvRealScalar(eta), weights, weights ); 12326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn }__END__; 12356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dY_dX_elems ); 12376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dY_dW_elems ); 12386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dE_dW ); 12396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 12426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* <dE_dY>, <dE_dX> should be row-vectors. 12436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn Function computes partial derivatives <dE_dX>, <dE_dW> 12446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn of the loss function with respect to the planes components 12456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn of the previous layer (X) and the weights of the current layer (W) 12466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn and updates weights od the current layer by using <dE_dW>. 12476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn It is a basic function for back propagation method. 12486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn Input parameter <dE_dY> is the partial derivative of the 12496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn loss function with respect to the planes components 12506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn of the current layer. */ 12516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectBackward( CvCNNLayer* _layer, 12526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int t, 12536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* X, 12546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* dE_dY, 12556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dE_dX ) 12566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dE_dY_activ_func_der = 0; 12586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* dE_dW = 0; 12596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "icvCNNFullConnectBackward" ); 12616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_FULLCONNECT_LAYER(_layer) ) 12636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 12646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn {__BEGIN__; 12666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvCNNFullConnectLayer* layer = (CvCNNFullConnectLayer*)_layer; 12686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int n_outputs = layer->n_output_planes; 12696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int n_inputs = layer->n_input_planes; 12706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i; 12726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* dE_dY_activ_func_der_data; 12736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* weights = layer->weights; 12746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat sub_weights, Xtemplate, Xrow, exp2ssumWXrow; 12756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT(X->cols == 1 && X->rows == n_inputs); 12776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT(dE_dY->rows == 1 && dE_dY->cols == n_outputs ); 12786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT(dE_dX->rows == 1 && dE_dX->cols == n_inputs ); 12796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // we violate the convetion about vector's orientation because 12816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // here is more convenient to make this parameter a row-vector 12826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(dE_dY_activ_func_der = cvCreateMat( 1, n_outputs, CV_32FC1 )); 12836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(dE_dW = cvCreateMat( 1, weights->rows*weights->cols, CV_32FC1 )); 12846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 1) compute gradients dE_dX and dE_dW 12866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // activ_func_der == 4as*(layer->exp2ssumWX)/(layer->exp2ssumWX + 1)^2 12876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvReshape( layer->exp2ssumWX, &exp2ssumWXrow, 0, layer->exp2ssumWX->cols )); 12886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvAddS( &exp2ssumWXrow, cvRealScalar(1), dE_dY_activ_func_der )); 12896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvPow( dE_dY_activ_func_der, dE_dY_activ_func_der, -2.0 )); 12906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvMul( dE_dY_activ_func_der, &exp2ssumWXrow, dE_dY_activ_func_der, 12916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4.0*layer->a*layer->s )); 12926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvMul( dE_dY, dE_dY_activ_func_der, dE_dY_activ_func_der )); 12936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sub_weights = d(W*(X|1))/dX 12956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvGetSubRect( weights, &sub_weights, 12966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvRect(0, 0, weights->cols-1, weights->rows) )); 12976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvMatMul( dE_dY_activ_func_der, &sub_weights, dE_dX )); 12986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReshape( X, &Xrow, 0, 1 ); 13006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dE_dY_activ_func_der_data = dE_dY_activ_func_der->data.fl; 13016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn Xtemplate = cvMat( 1, n_inputs, CV_32FC1, dE_dW->data.fl ); 13026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n_outputs; i++, Xtemplate.data.fl += n_inputs + 1 ) 13036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvConvertScale( &Xrow, &Xtemplate, *dE_dY_activ_func_der_data )); 13056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn Xtemplate.data.fl[n_inputs] = *dE_dY_activ_func_der_data++; 13066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 2) update weights 13096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat dE_dW_mat; 13116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float eta; 13126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_LOG_INV ) 13136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn eta = -layer->init_learn_rate/logf(1+(float)t); 13146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_SQRT_INV ) 13156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn eta = -layer->init_learn_rate/sqrtf((float)t); 13166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 13176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn eta = -layer->init_learn_rate/(float)t; 13186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReshape( dE_dW, &dE_dW_mat, 0, n_outputs ); 13196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvScaleAdd( &dE_dW_mat, cvRealScalar(eta), weights, weights ); 13206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn }__END__; 13236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dE_dY_activ_func_der ); 13256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &dE_dW ); 13266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 13276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\ 13296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn* Layer RELEASE functions * 13306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/ 13316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionRelease( CvCNNLayer** p_layer ) 13326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 13336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNConvolutionRelease"); 13346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 13356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNConvolutionLayer* layer = 0; 13376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !p_layer ) 13396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, "Null double pointer" ); 13406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer = *(CvCNNConvolutionLayer**)p_layer; 13426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !layer ) 13446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 13456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_CONVOLUTION_LAYER(layer) ) 13466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 13476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->weights ); 13496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->connect_mask ); 13506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( p_layer ); 13516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 13536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 13546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 13566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingRelease( CvCNNLayer** p_layer ) 13576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 13586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNSubSamplingRelease"); 13596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 13606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNSubSamplingLayer* layer = 0; 13626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !p_layer ) 13646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, "Null double pointer" ); 13656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer = *(CvCNNSubSamplingLayer**)p_layer; 13676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !layer ) 13696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 13706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_SUBSAMPLING_LAYER(layer) ) 13716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 13726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->exp2ssumWX ); 13746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->weights ); 13756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( p_layer ); 13766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 13786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 13796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 13816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectRelease( CvCNNLayer** p_layer ) 13826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 13836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvCNNFullConnectRelease"); 13846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 13856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNFullConnectLayer* layer = 0; 13876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !p_layer ) 13896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, "Null double pointer" ); 13906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer = *(CvCNNFullConnectLayer**)p_layer; 13926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !layer ) 13946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 13956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_FULLCONNECT_LAYER(layer) ) 13966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 13976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->exp2ssumWX ); 13996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &layer->weights ); 14006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( p_layer ); 14016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 14036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 14046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\ 14066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn* Read/Write CNN classifier * 14076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/ 14086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic int icvIsCNNModel( const void* ptr ) 14096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 14106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return CV_IS_CNN(ptr); 14116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 14126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 14146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvReleaseCNNModel( void** ptr ) 14156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 14166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvReleaseCNNModel"); 14176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 14186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ptr ) 14206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNullPtr, "NULL double pointer" ); 14216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT(CV_IS_CNN(*ptr)); 14226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvCNNModelRelease( (CvStatModel**)ptr ); 14246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 14266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 14276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 14296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CvCNNLayer* icvReadCNNLayer( CvFileStorage* fs, CvFileNode* node ) 14306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 14316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* layer = 0; 14326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* weights = 0; 14336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* connect_mask = 0; 14346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvReadCNNLayer"); 14366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 14376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_input_planes, input_height, input_width; 14396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_output_planes, output_height, output_width; 14406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int learn_type, layer_type; 14416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float init_learn_rate; 14426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(n_input_planes = cvReadIntByName( fs, node, "n_input_planes", -1 )); 14446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(input_height = cvReadIntByName( fs, node, "input_height", -1 )); 14456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(input_width = cvReadIntByName( fs, node, "input_width", -1 )); 14466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(n_output_planes = cvReadIntByName( fs, node, "n_output_planes", -1 )); 14476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(output_height = cvReadIntByName( fs, node, "output_height", -1 )); 14486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(output_width = cvReadIntByName( fs, node, "output_width", -1 )); 14496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer_type = cvReadIntByName( fs, node, "layer_type", -1 )); 14506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(init_learn_rate = (float)cvReadRealByName( fs, node, "init_learn_rate", -1 )); 14526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(learn_type = cvReadIntByName( fs, node, "learn_rate_decrease_type", -1 )); 14536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(weights = (CvMat*)cvReadByName( fs, node, "weights" )); 14546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( n_input_planes < 0 || input_height < 0 || input_width < 0 || 14566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n_output_planes < 0 || output_height < 0 || output_width < 0 || 14576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn init_learn_rate < 0 || learn_type < 0 || layer_type < 0 || !weights ) 14586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "" ); 14596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( layer_type == ICV_CNN_CONVOLUTION_LAYER ) 14616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int K = input_height - output_height + 1; 14636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( K <= 0 || K != input_width - output_width + 1 ) 14646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid <K>" ); 14656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(connect_mask = (CvMat*)cvReadByName( fs, node, "connect_mask" )); 14676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !connect_mask ) 14686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "Missing <connect mask>" ); 14696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer = cvCreateCNNConvolutionLayer( 14716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n_input_planes, input_height, input_width, n_output_planes, K, 14726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn init_learn_rate, learn_type, connect_mask, weights )); 14736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( layer_type == ICV_CNN_SUBSAMPLING_LAYER ) 14756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float a, s; 14776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int sub_samp_scale = input_height/output_height; 14786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sub_samp_scale <= 0 || sub_samp_scale != input_width/output_width ) 14806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid <sub_samp_scale>" ); 14816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(a = (float)cvReadRealByName( fs, node, "a", -1 )); 14836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(s = (float)cvReadRealByName( fs, node, "s", -1 )); 14846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( a < 0 || s < 0 ) 14856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "Missing <a> or <s>" ); 14866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer = cvCreateCNNSubSamplingLayer( 14886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n_input_planes, input_height, input_width, sub_samp_scale, 14896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn a, s, init_learn_rate, learn_type, weights )); 14906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( layer_type == ICV_CNN_FULLCONNECT_LAYER ) 14926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float a, s; 14946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(a = (float)cvReadRealByName( fs, node, "a", -1 )); 14956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(s = (float)cvReadRealByName( fs, node, "s", -1 )); 14966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( a < 0 || s < 0 ) 14976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "" ); 14986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( input_height != 1 || input_width != 1 || 14996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn output_height != 1 || output_width != 1 ) 15006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "" ); 15016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer = cvCreateCNNFullConnectLayer( n_input_planes, n_output_planes, 15036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn a, s, init_learn_rate, learn_type, weights )); 15046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 15066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid <layer_type>" ); 15076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 15096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cvGetErrStatus() < 0 && layer ) 15116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer->release( &layer ); 15126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &weights ); 15146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &connect_mask ); 15156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return layer; 15176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 15186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 15206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvWriteCNNLayer( CvFileStorage* fs, CvCNNLayer* layer ) 15216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 15226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME ("icvWriteCNNLayer"); 15236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 15246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !ICV_IS_CNN_LAYER(layer) ) 15266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 15276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvStartWriteStruct( fs, NULL, CV_NODE_MAP, "opencv-ml-cnn-layer" )); 15296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "n_input_planes", layer->n_input_planes )); 15316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "input_height", layer->input_height )); 15326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "input_width", layer->input_width )); 15336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "n_output_planes", layer->n_output_planes )); 15346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "output_height", layer->output_height )); 15356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "output_width", layer->output_width )); 15366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "learn_rate_decrease_type", layer->learn_rate_decrease_type)); 15376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteReal( fs, "init_learn_rate", layer->init_learn_rate )); 15386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWrite( fs, "weights", layer->weights )); 15396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ICV_IS_CNN_CONVOLUTION_LAYER( layer )) 15416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNConvolutionLayer* l = (CvCNNConvolutionLayer*)layer; 15436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "layer_type", ICV_CNN_CONVOLUTION_LAYER )); 15446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWrite( fs, "connect_mask", l->connect_mask )); 15456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( ICV_IS_CNN_SUBSAMPLING_LAYER( layer ) ) 15476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNSubSamplingLayer* l = (CvCNNSubSamplingLayer*)layer; 15496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "layer_type", ICV_CNN_SUBSAMPLING_LAYER )); 15506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteReal( fs, "a", l->a )); 15516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteReal( fs, "s", l->s )); 15526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( ICV_IS_CNN_FULLCONNECT_LAYER( layer ) ) 15546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNFullConnectLayer* l = (CvCNNFullConnectLayer*)layer; 15566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteInt( fs, "layer_type", ICV_CNN_FULLCONNECT_LAYER )); 15576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteReal( fs, "a", l->a )); 15586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWriteReal( fs, "s", l->s )); 15596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 15616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid layer" ); 15626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvEndWriteStruct( fs )); //"opencv-ml-cnn-layer" 15646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 15666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 15676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 15696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void* icvReadCNNModel( CvFileStorage* fs, CvFileNode* root_node ) 15706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 15716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNStatModel* cnn = 0; 15726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* layer = 0; 15736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME("icvReadCNNModel"); 15756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 15766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* node; 15786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeq* seq; 15796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 15806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i; 15816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cnn = (CvCNNStatModel*)cvCreateStatModel( 15836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_STAT_MODEL_MAGIC_VAL|CV_CNN_MAGIC_VAL, sizeof(CvCNNStatModel), 15846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvCNNModelRelease, icvCNNModelPredict, icvCNNModelUpdate )); 15856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cnn->etalons = (CvMat*)cvReadByName( fs, root_node, "etalons" )); 15876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cnn->cls_labels = (CvMat*)cvReadByName( fs, root_node, "cls_labels" )); 15886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !cnn->etalons || !cnn->cls_labels ) 15906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "No <etalons> or <cls_labels> in CNN model" ); 15916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( node = cvGetFileNodeByName( fs, root_node, "network" )); 15936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn seq = node->data.seq; 15946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_NODE_IS_SEQ(node->tag) ) 15956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "" ); 15966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvStartReadSeq( seq, &reader, 0 )); 15986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer = icvReadCNNLayer( fs, (CvFileNode*)reader.ptr )); 15996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cnn->network = cvCreateCNNetwork( layer )); 16006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 1; i < seq->total; i++ ) 16026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_NEXT_SEQ_ELEM( seq->elem_size, reader ); 16046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(layer = icvReadCNNLayer( fs, (CvFileNode*)reader.ptr )); 16056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cnn->network->add_layer( cnn->network, layer )); 16066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 16096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cvGetErrStatus() < 0 ) 16116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cnn ) cnn->release( (CvStatModel**)&cnn ); 16136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( layer ) layer->release( &layer ); 16146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return (void*)cnn; 16166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 16176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/ 16196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void 16206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennicvWriteCNNModel( CvFileStorage* fs, const char* name, 16216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const void* struct_ptr, CvAttrList ) 16226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 16246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME ("icvWriteCNNModel"); 16256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 16266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNStatModel* cnn = (CvCNNStatModel*)struct_ptr; 16286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_layers, i; 16296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvCNNLayer* layer; 16306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_CNN(cnn) ) 16326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid pointer" ); 16336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n_layers = cnn->network->n_layers; 16356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_CNN )); 16376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWrite( fs, "etalons", cnn->etalons )); 16396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(cvWrite( fs, "cls_labels", cnn->cls_labels )); 16406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvStartWriteStruct( fs, "network", CV_NODE_SEQ )); 16426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn layer = cnn->network->layers; 16446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n_layers && layer; i++, layer = layer->next_layer ) 16456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(icvWriteCNNLayer( fs, layer )); 16466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( i < n_layers || layer ) 16476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Invalid network" ); 16486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvEndWriteStruct( fs )); //"network" 16506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvEndWriteStruct( fs )); //"opencv-ml-cnn" 16516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 16536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 16546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic int icvRegisterCNNStatModelType() 16566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 16576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvTypeInfo info; 16586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn info.header_size = sizeof( info ); 16606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn info.is_instance = icvIsCNNModel; 16616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn info.release = icvReleaseCNNModel; 16626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn info.read = icvReadCNNModel; 16636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn info.write = icvWriteCNNModel; 16646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn info.clone = NULL; 16656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn info.type_name = CV_TYPE_NAME_ML_CNN; 16666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvRegisterType( &info ); 16676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 1; 16696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} // End of icvRegisterCNNStatModelType 16706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic int cnn = icvRegisterCNNStatModelType(); 16726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#endif 16746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// End of file 1676