16acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*M///////////////////////////////////////////////////////////////////////////////////////
26acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
36acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
46acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
56acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  By downloading, copying, installing or using the software you agree to this license.
66acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  If you do not agree to this license, do not download, install,
76acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  copy or use the software.
86acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
96acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//                        Intel License Agreement
116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Copyright (C) 2000, Intel Corporation, all rights reserved.
136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Third party copyrights are property of their respective owners.
146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Redistribution and use in source and binary forms, with or without modification,
166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are permitted provided that the following conditions are met:
176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's of source code must retain the above copyright notice,
196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer.
206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's in binary form must reproduce the above copyright notice,
226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer in the documentation
236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     and/or other materials provided with the distribution.
246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * The name of Intel Corporation may not be used to endorse or promote products
266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     derived from this software without specific prior written permission.
276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// This software is provided by the copyright holders and contributors "as is" and
296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// any express or implied warranties, including, but not limited to, the implied
306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// warranties of merchantability and fitness for a particular purpose are disclaimed.
316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct,
326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages
336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// (including, but not limited to, procurement of substitute goods or services;
346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused
356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability,
366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort (including negligence or otherwise) arising in any way out of
376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even if advised of the possibility of such damage.
386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/
406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h"
426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#if 0
446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
*                         Auxiliary function declarations                                *
466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*---------------------- functions for the CNN classifier ------------------------------*/
486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic float icvCNNModelPredict(
496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvStatModel* cnn_model,
506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat* image,
516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat* probs CV_DEFAULT(0) );
526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNModelUpdate(
546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvStatModel* cnn_model, const CvMat* images, int tflag,
556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat* responses, const CvStatModelParams* params,
566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat* CV_DEFAULT(0), const CvMat* sample_idx CV_DEFAULT(0),
576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0));
586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNModelRelease( CvStatModel** cnn_model );
606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvTrainCNNetwork( CvCNNetwork* network,
626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               const float** images,
636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               const CvMat* responses,
646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               const CvMat* etalons,
656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               int grad_estim_type,
666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               int max_iter,
676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               int start_iter );
686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*------------------------- functions for the CNN network ------------------------------*/
706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNetworkAddLayer( CvCNNetwork* network, CvCNNLayer* layer );
716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNetworkRelease( CvCNNetwork** network );
726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* In all layer functions we denote input by X and output by Y, where
746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   X and Y are column-vectors, so that
756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   length(X)==<n_input_planes>*<input_height>*<input_width>,
766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   length(Y)==<n_output_planes>*<output_height>*<output_width>.
776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*/
786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*------------------------ functions for convolutional layer ---------------------------*/
796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionRelease( CvCNNLayer** p_layer );
806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionForward( CvCNNLayer* layer, const CvMat* X, CvMat* Y );
826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionBackward( CvCNNLayer*  layer, int t,
846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*------------------------ functions for sub-sampling layer ----------------------------*/
876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingRelease( CvCNNLayer** p_layer );
886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingForward( CvCNNLayer* layer, const CvMat* X, CvMat* Y );
906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingBackward( CvCNNLayer*  layer, int t,
926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*------------------------ functions for full connected layer --------------------------*/
956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectRelease( CvCNNLayer** p_layer );
966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectForward( CvCNNLayer* layer, const CvMat* X, CvMat* Y );
986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectBackward( CvCNNLayer* layer, int,
1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvMat*, const CvMat* dE_dY, CvMat* dE_dX );
1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*                             Functions implementations                                  *
1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
/* ICV_CHECK_CNN_NETWORK(network) - validate a network before training it.

   Reports an error through CV_ERROR unless all of the following hold:
     - <network> is non-null;
     - following next_layer from network->layers visits exactly network->n_layers
       (at least one) valid CNN layers, the first layer has no prev_layer and the
       last one has no next_layer;
     - the first layer has a single input plane of exactly img_size pixels;
     - the output length of the last layer equals params->etalons->cols.

   The macro is not self-contained.  It reads `img_size` and `params` from the scope
   in which it is expanded.  Because it reports errors with CV_ERROR, it is used inside
   the caller's __BEGIN__/__END__ section.  The do/while(0) wrapper makes it a single
   statement, so `ICV_CHECK_CNN_NETWORK(x);` is also safe as an if/else body. */
#define ICV_CHECK_CNN_NETWORK(network)                                                  \
do                                                                                      \
{                                                                                       \
    CvCNNLayer* first_layer, *layer, *last_layer;                                       \
    int n_layers, i;                                                                    \
    if( !(network) )                                                                    \
        CV_ERROR( CV_StsNullPtr,                                                        \
        "Null <network> pointer. Network must be created by user." );                   \
    n_layers = (network)->n_layers;                                                     \
    first_layer = last_layer = (network)->layers;                                       \
    for( i = 0, layer = first_layer; i < n_layers && layer; i++ )                       \
    {                                                                                   \
        if( !ICV_IS_CNN_LAYER(layer) )                                                  \
            CV_ERROR( CV_StsNullPtr, "Invalid network" );                               \
        last_layer = layer;                                                             \
        layer = layer->next_layer;                                                      \
    }                                                                                   \
                                                                                        \
    if( i == 0 || i != n_layers || first_layer->prev_layer || layer )                   \
        CV_ERROR( CV_StsNullPtr, "Invalid network" );                                   \
                                                                                        \
    if( first_layer->n_input_planes != 1 )                                              \
        CV_ERROR( CV_StsBadArg, "First layer must contain only one input plane" );      \
                                                                                        \
    if( img_size != first_layer->input_height*first_layer->input_width )                \
        CV_ERROR( CV_StsBadArg, "Invalid input sizes of the first layer" );             \
                                                                                        \
    if( params->etalons->cols != last_layer->n_output_planes*                           \
        last_layer->output_height*last_layer->output_width )                            \
        CV_ERROR( CV_StsBadArg, "Invalid output sizes of the last layer" );             \
} while( 0 )
1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
/* ICV_CHECK_CNN_MODEL_PARAMS(params) - validate CNN training parameters.

   Reports an error through CV_ERROR unless all of the following hold:
     - <params> is non-null;
     - params->etalons is a CV_32FC1 matrix whose row count equals
       cnn_model->cls_labels->cols (presumably one etalon row per class);
     - grad_estim_type is CV_CNN_GRAD_ESTIM_RANDOM or CV_CNN_GRAD_ESTIM_BY_WORST_IMG;
     - start_iter is not negative.
   A max_iter below 1 is clamped to 1 in place, so <params> must point to writable
   memory.

   The macro reads `cnn_model` from the scope in which it is expanded and reports
   errors with CV_ERROR.  It is wrapped in do/while(0) so that it behaves as a single
   statement. */
#define ICV_CHECK_CNN_MODEL_PARAMS(params)                                              \
do                                                                                      \
{                                                                                       \
    if( !(params) )                                                                     \
        CV_ERROR( CV_StsNullPtr, "Null <params> pointer" );                             \
                                                                                        \
    if( !ICV_IS_MAT_OF_TYPE((params)->etalons, CV_32FC1) )                              \
        CV_ERROR( CV_StsBadArg, "<etalons> must be CV_32FC1 type" );                    \
    if( (params)->etalons->rows != cnn_model->cls_labels->cols )                        \
        CV_ERROR( CV_StsBadArg, "Invalid <etalons> size" );                             \
                                                                                        \
    if( (params)->grad_estim_type != CV_CNN_GRAD_ESTIM_RANDOM &&                        \
        (params)->grad_estim_type != CV_CNN_GRAD_ESTIM_BY_WORST_IMG )                   \
        CV_ERROR( CV_StsBadArg, "Invalid <grad_estim_type>" );                          \
                                                                                        \
    if( (params)->start_iter < 0 )                                                      \
        CV_ERROR( CV_StsBadArg, "Parameter <start_iter> must be positive or zero" );    \
                                                                                        \
    if( (params)->max_iter < 1 )                                                        \
        (params)->max_iter = 1;                                                         \
} while( 0 )
1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*                              Classifier functions                                      *
1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvStatModel*
1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RenncvTrainCNNClassifier( const CvMat* _train_data, int tflag,
1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const CvMat* _responses,
1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const CvStatModelParams* _params,
1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const CvMat*, const CvMat* _sample_idx, const CvMat*, const CvMat* )
1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNStatModel* cnn_model    = 0;
1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float** out_train_data = 0;
1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* responses             = 0;
1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("cvTrainCNNClassifier");
1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_images;
1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int img_size;
1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNStatModelParams* params = (CvCNNStatModelParams*)_params;
1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cnn_model = (CvCNNStatModel*)cvCreateStatModel(
1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_STAT_MODEL_MAGIC_VAL|CV_CNN_MAGIC_VAL, sizeof(CvCNNStatModel),
1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        icvCNNModelRelease, icvCNNModelPredict, icvCNNModelUpdate ));
1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvPrepareTrainData( "cvTrainCNNClassifier",
1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _train_data, tflag, _responses, CV_VAR_CATEGORICAL,
1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        0, _sample_idx, false, &out_train_data,
1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        &n_images, &img_size, &img_size, &responses,
1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        &cnn_model->cls_labels, 0 ));
1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ICV_CHECK_CNN_MODEL_PARAMS(params);
1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ICV_CHECK_CNN_NETWORK(params->network);
1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cnn_model->network = params->network;
1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cnn_model->etalons = (CvMat*)cvClone( params->etalons ));
1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( icvTrainCNNetwork( cnn_model->network, out_train_data, responses,
1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cnn_model->etalons, params->grad_estim_type, params->max_iter,
1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params->start_iter ));
1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cvGetErrStatus() < 0 && cnn_model )
2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cnn_model->release( (CvStatModel**)&cnn_model );
2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &out_train_data );
2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &responses );
2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (CvStatModel*)cnn_model;
2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvTrainCNNetwork( CvCNNetwork* network,
2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               const float** images,
2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               const CvMat* responses,
2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               const CvMat* etalons,
2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               int grad_estim_type,
2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               int max_iter,
2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                               int start_iter )
2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat** X     = 0;
2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat** dE_dX = 0;
2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int n_layers = network->n_layers;
2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int k;
2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvTrainCNNetwork");
2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* first_layer = network->layers;
2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int img_height = first_layer->input_height;
2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int img_width  = first_layer->input_width;
2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int img_size   = img_width*img_height;
2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int n_images   = responses->cols;
2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat image = cvMat( 1, img_size, CV_32FC1 );
2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* layer;
2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n;
2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvRNG rng = cvRNG(-1);
2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(X = (CvMat**)cvAlloc( (n_layers+1)*sizeof(CvMat*) ));
2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(dE_dX = (CvMat**)cvAlloc( (n_layers+1)*sizeof(CvMat*) ));
2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    memset( X, 0, (n_layers+1)*sizeof(CvMat*) );
2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    memset( dE_dX, 0, (n_layers+1)*sizeof(CvMat*) );
2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(X[0] = cvCreateMat( img_height*img_width,1,CV_32FC1 ));
2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(dE_dX[0] = cvCreateMat( 1, X[0]->rows, CV_32FC1 ));
2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0, layer = first_layer; k < n_layers; k++, layer = layer->next_layer )
2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(X[k+1] = cvCreateMat( layer->n_output_planes*layer->output_height*
2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            layer->output_width, 1, CV_32FC1 ));
2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(dE_dX[k+1] = cvCreateMat( 1, X[k+1]->rows, CV_32FC1 ));
2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( n = 1; n <= max_iter; n++ )
2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float loss, max_loss = 0;
2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i;
2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int worst_img_idx = -1;
2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* right_etal_idx = responses->data.i;
2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat etalon;
2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // Find the worst image (which produces the greatest loss) or use the random image
2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( grad_estim_type == CV_CNN_GRAD_ESTIM_BY_WORST_IMG )
2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < n_images; i++, right_etal_idx++ )
2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                image.data.fl = (float*)images[i];
2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvTranspose( &image, X[0] );
2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( k = 0, layer = first_layer; k < n_layers; k++, layer = layer->next_layer )
2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    CV_CALL(layer->forward( layer, X[k], X[k+1] ));
2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvTranspose( X[n_layers], dE_dX[n_layers] );
2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvGetRow( etalons, &etalon, *right_etal_idx );
2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                loss = (float)cvNorm( dE_dX[n_layers], &etalon );
2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( loss > max_loss )
2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    max_loss = loss;
2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    worst_img_idx = i;
2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            worst_img_idx = cvRandInt(&rng) % n_images;
2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // Train network on the worst image
2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // 1) Compute the network output on the <image>
2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        image.data.fl = (float*)images[worst_img_idx];
2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvTranspose( &image, X[0] ));
2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0, layer = first_layer; k < n_layers - 1; k++, layer = layer->next_layer )
2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL(layer->forward( layer, X[k], X[k+1] ));
2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(layer->forward( layer, X[k], X[k+1] ));
2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // 2) Compute the gradient
2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvTranspose( X[n_layers], dE_dX[n_layers] );
2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGetRow( etalons, &etalon, responses->data.i[worst_img_idx] );
2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSub( dE_dX[n_layers], &etalon, dE_dX[n_layers] );
2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // 3) Update weights by the gradient descent
2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = n_layers; k > 0; k--, layer = layer->prev_layer )
2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL(layer->backward( layer, n + start_iter, X[k-1], dE_dX[k], dE_dX[k-1] ));
3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k <= n_layers; k++ )
3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &X[k] );
3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &dE_dX[k] );
3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &X );
3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &dE_dX );
3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic float icvCNNModelPredict( const CvStatModel* model,
3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                 const CvMat* _image,
3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                 CvMat* probs )
3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat** X       = 0;
3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* img_data = 0;
3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_layers = 0;
3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int best_etal_idx = -1;
3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int k;
3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNModelPredict");
3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNStatModel* cnn_model = (CvCNNStatModel*)model;
3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* first_layer, *layer = 0;
3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int img_height, img_width, img_size;
3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int nclasses, i;
3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float loss, min_loss = FLT_MAX;
3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* probs_data;
3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat etalon, image;
3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_CNN(model) )
3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid model" );
3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nclasses = cnn_model->cls_labels->cols;
3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    n_layers = cnn_model->network->n_layers;
3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    first_layer   = cnn_model->network->layers;
3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    img_height = first_layer->input_height;
3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    img_width  = first_layer->input_width;
3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    img_size   = img_height*img_width;
3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvPreparePredictData( _image, img_size, 0, nclasses, probs, &img_data );
3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(X = (CvMat**)cvAlloc( (n_layers+1)*sizeof(CvMat*) ));
3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    memset( X, 0, (n_layers+1)*sizeof(CvMat*) );
3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(X[0] = cvCreateMat( img_size,1,CV_32FC1 ));
3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0, layer = first_layer; k < n_layers; k++, layer = layer->next_layer )
3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(X[k+1] = cvCreateMat( layer->n_output_planes*layer->output_height*
3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            layer->output_width, 1, CV_32FC1 ));
3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    image = cvMat( 1, img_size, CV_32FC1, img_data );
3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvTranspose( &image, X[0] );
3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0, layer = first_layer; k < n_layers; k++, layer = layer->next_layer )
3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(layer->forward( layer, X[k], X[k+1] ));
3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    probs_data = probs ? probs->data.fl : 0;
3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    etalon = cvMat( cnn_model->etalons->cols, 1, CV_32FC1, cnn_model->etalons->data.fl );
3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < nclasses; i++, etalon.data.fl += cnn_model->etalons->cols )
3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        loss = (float)cvNorm( X[n_layers], &etalon );
3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( loss < min_loss )
3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            min_loss = loss;
3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            best_etal_idx = i;
3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( probs )
3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            *probs_data++ = -loss;
3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( probs )
3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvExp( probs, probs );
3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvScalar sum = cvSum( probs );
3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvConvertScale( probs, probs, 1./sum.val[0] );
3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k <= n_layers; k++ )
3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &X[k] );
3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &X );
3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( img_data != _image->data.fl )
3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &img_data );
3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return ((float) ((CvCNNStatModel*)model)->cls_labels->data.i[best_etal_idx]);
3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNModelUpdate(
3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvStatModel* _cnn_model, const CvMat* _train_data, int tflag,
3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat* _responses, const CvStatModelParams* _params,
3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat*, const CvMat* _sample_idx,
3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat*, const CvMat* )
4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float** out_train_data = 0;
4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* responses             = 0;
4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* cls_labels            = 0;
4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNModelUpdate");
4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_images, img_size, i;
4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNStatModelParams* params = (CvCNNStatModelParams*)_params;
4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNStatModel* cnn_model = (CvCNNStatModel*)_cnn_model;
4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_CNN(cnn_model) )
4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid model" );
4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvPrepareTrainData( "cvTrainCNNClassifier",
4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _train_data, tflag, _responses, CV_VAR_CATEGORICAL,
4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        0, _sample_idx, false, &out_train_data,
4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        &n_images, &img_size, &img_size, &responses,
4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        &cls_labels, 0, 0 ));
4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ICV_CHECK_CNN_MODEL_PARAMS(params);
4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // Number of classes must be the same as when classifiers was created
4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_ARE_SIZES_EQ(cls_labels, cnn_model->cls_labels) )
4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Number of classes must be left unchanged" );
4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < cls_labels->cols; i++ )
4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( cls_labels->data.i[i] != cnn_model->cls_labels->data.i[i] )
4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "Number of classes must be left unchanged" );
4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( icvTrainCNNetwork( cnn_model->network, out_train_data, responses,
4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cnn_model->etalons, params->grad_estim_type, params->max_iter,
4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params->start_iter ));
4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &out_train_data );
4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &responses );
4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNModelRelease( CvStatModel** cnn_model )
4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNModelRelease");
4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNStatModel* cnn;
4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !cnn_model )
4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsNullPtr, "Null double pointer" );
4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cnn = *(CvCNNStatModel**)cnn_model;
4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &cnn->cls_labels );
4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &cnn->etalons );
4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cnn->network->release( &cnn->network );
4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &cnn );
4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*                                 Network functions                                      *
4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvCNNetwork* cvCreateCNNetwork( CvCNNLayer* first_layer )
4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNetwork* network = 0;
4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "cvCreateCNNetwork" );
4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_LAYER(first_layer) )
4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(network = (CvCNNetwork*)cvAlloc( sizeof(CvCNNetwork) ));
4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    memset( network, 0, sizeof(CvCNNetwork) );
4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    network->layers    = first_layer;
4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    network->n_layers  = 1;
4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    network->release   = icvCNNetworkRelease;
4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    network->add_layer = icvCNNetworkAddLayer;
4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cvGetErrStatus() < 0 && network )
4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &network );
4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return network;
4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNetworkAddLayer( CvCNNetwork* network, CvCNNLayer* layer )
4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "icvCNNetworkAddLayer" );
4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* prev_layer;
5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( network == NULL )
5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsNullPtr, "Null <network> pointer" );
5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    prev_layer = network->layers;
5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    while( prev_layer->next_layer )
5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        prev_layer = prev_layer->next_layer;
5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( ICV_IS_CNN_FULLCONNECT_LAYER(layer) )
5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( layer->n_input_planes != prev_layer->output_width*prev_layer->output_height*
5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            prev_layer->n_output_planes )
5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "Unmatched size of the new layer" );
5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( layer->input_height != 1 || layer->output_height != 1 ||
5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            layer->input_width != 1  || layer->output_width != 1 )
5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "Invalid size of the new layer" );
5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( ICV_IS_CNN_CONVOLUTION_LAYER(layer) || ICV_IS_CNN_SUBSAMPLING_LAYER(layer) )
5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( prev_layer->n_output_planes != layer->n_input_planes ||
5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        prev_layer->output_height   != layer->input_height ||
5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        prev_layer->output_width    != layer->input_width )
5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Unmatched size of the new layer" );
5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->prev_layer = prev_layer;
5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    prev_layer->next_layer = layer;
5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    network->n_layers++;
5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNetworkRelease( CvCNNetwork** network_pptr )
5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "icvReleaseCNNetwork" );
5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNetwork* network = 0;
5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* layer = 0, *next_layer = 0;
5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int k;
5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( network_pptr == NULL )
5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Null double pointer" );
5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( *network_pptr == NULL )
5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return;
5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    network = *network_pptr;
5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer = network->layers;
5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( layer == NULL )
5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "CNN is empty (does not contain any layer)" );
5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // k is the number of the layer to be deleted
5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k < network->n_layers && layer; k++ )
5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        next_layer = layer->next_layer;
5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        layer->release( &layer );
5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        layer = next_layer;
5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( k != network->n_layers || layer)
5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid network" );
5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &network );
5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*                                  Layer functions                                       *
5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CvCNNLayer* icvCreateCNNLayer( int layer_type, int header_size,
5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_input_planes, int input_height, int input_width,
5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_output_planes, int output_height, int output_width,
5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float init_learn_rate, int learn_rate_decrease_type,
5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayerRelease release, CvCNNLayerForward forward, CvCNNLayerBackward backward )
5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* layer = 0;
5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCreateCNNLayer");
5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT( release && forward && backward )
5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT( header_size >= sizeof(CvCNNLayer) )
5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( n_input_planes < 1 || n_output_planes < 1 ||
5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        input_height   < 1 || input_width < 1 ||
5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        output_height  < 1 || output_width < 1 ||
5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        input_height < output_height ||
5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        input_width  < output_width )
5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Incorrect input or output parameters" );
5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( init_learn_rate < FLT_EPSILON )
5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Initial learning rate must be positive" );
5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( learn_rate_decrease_type != CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY &&
5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        learn_rate_decrease_type != CV_CNN_LEARN_RATE_DECREASE_SQRT_INV &&
5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        learn_rate_decrease_type != CV_CNN_LEARN_RATE_DECREASE_LOG_INV )
5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid type of learning rate dynamics" );
6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer = (CvCNNLayer*)cvAlloc( header_size ));
6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    memset( layer, 0, header_size );
6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->flags = ICV_CNN_LAYER|layer_type;
6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT( ICV_IS_CNN_LAYER(layer) )
6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->n_input_planes = n_input_planes;
6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->input_height   = input_height;
6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->input_width    = input_width;
6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->n_output_planes = n_output_planes;
6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->output_height   = output_height;
6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->output_width    = output_width;
6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->init_learn_rate = init_learn_rate;
6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->learn_rate_decrease_type = learn_rate_decrease_type;
6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->release  = release;
6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->forward  = forward;
6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->backward = backward;
6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cvGetErrStatus() < 0 && layer)
6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &layer );
6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return layer;
6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvCNNLayer* cvCreateCNNConvolutionLayer(
6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_input_planes, int input_height, int input_width,
6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_output_planes, int K,
6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float init_learn_rate, int learn_rate_decrease_type,
6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* connect_mask, CvMat* weights )
6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNConvolutionLayer* layer = 0;
6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("cvCreateCNNConvolutionLayer");
6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int output_height = input_height - K + 1;
6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int output_width = input_width - K + 1;
6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( K < 1 || init_learn_rate <= 0 )
6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Incorrect parameters" );
6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer = (CvCNNConvolutionLayer*)icvCreateCNNLayer( ICV_CNN_CONVOLUTION_LAYER,
6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sizeof(CvCNNConvolutionLayer), n_input_planes, input_height, input_width,
6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        n_output_planes, output_height, output_width,
6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        init_learn_rate, learn_rate_decrease_type,
6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        icvCNNConvolutionRelease, icvCNNConvolutionForward, icvCNNConvolutionBackward ));
6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->K = K;
6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer->weights = cvCreateMat( n_output_planes, K*K+1, CV_32FC1 ));
6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer->connect_mask = cvCreateMat( n_output_planes, n_input_planes, CV_8UC1));
6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( weights )
6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !ICV_IS_MAT_OF_TYPE( weights, CV_32FC1 ) )
6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadSize, "Type of initial weights matrix must be CV_32FC1" );
6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_ARE_SIZES_EQ( weights, layer->weights ) )
6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadSize, "Invalid size of initial weights matrix" );
6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvCopy( weights, layer->weights ));
6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvRNG rng = cvRNG( 0xFFFFFFFF );
6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvRandArr( &rng, layer->weights, CV_RAND_UNI, cvRealScalar(-1), cvRealScalar(1) );
6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( connect_mask )
6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !ICV_IS_MAT_OF_TYPE( connect_mask, CV_8UC1 ) )
6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadSize, "Type of connection matrix must be CV_32FC1" );
6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_ARE_SIZES_EQ( connect_mask, layer->connect_mask ) )
6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadSize, "Invalid size of connection matrix" );
6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvCopy( connect_mask, layer->connect_mask ));
6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvSet( layer->connect_mask, cvRealScalar(1) ));
6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cvGetErrStatus() < 0 && layer )
6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &layer->weights );
6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &layer->connect_mask );
6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &layer );
6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (CvCNNLayer*)layer;
6946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
6976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvCNNLayer* cvCreateCNNSubSamplingLayer(
6986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_input_planes, int input_height, int input_width,
6996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int sub_samp_scale, float a, float s,
7006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float init_learn_rate, int learn_rate_decrease_type, CvMat* weights )
7016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
7036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNSubSamplingLayer* layer = 0;
7046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("cvCreateCNNSubSamplingLayer");
7066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
7076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int output_height   = input_height/sub_samp_scale;
7096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int output_width    = input_width/sub_samp_scale;
7106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int n_output_planes = n_input_planes;
7116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( sub_samp_scale < 1 || a <= 0 || s <= 0)
7136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Incorrect parameters" );
7146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer = (CvCNNSubSamplingLayer*)icvCreateCNNLayer( ICV_CNN_SUBSAMPLING_LAYER,
7166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sizeof(CvCNNSubSamplingLayer), n_input_planes, input_height, input_width,
7176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        n_output_planes, output_height, output_width,
7186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        init_learn_rate, learn_rate_decrease_type,
7196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        icvCNNSubSamplingRelease, icvCNNSubSamplingForward, icvCNNSubSamplingBackward ));
7206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->sub_samp_scale  = sub_samp_scale;
7226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->a               = a;
7236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->s               = s;
7246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer->sumX =
7266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvCreateMat( n_output_planes*output_width*output_height, 1, CV_32FC1 ));
7276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer->exp2ssumWX =
7286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvCreateMat( n_output_planes*output_width*output_height, 1, CV_32FC1 ));
7296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( layer->sumX );
7316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( layer->exp2ssumWX );
7326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer->weights = cvCreateMat( n_output_planes, 2, CV_32FC1 ));
7346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( weights )
7356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !ICV_IS_MAT_OF_TYPE( weights, CV_32FC1 ) )
7376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadSize, "Type of initial weights matrix must be CV_32FC1" );
7386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_ARE_SIZES_EQ( weights, layer->weights ) )
7396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadSize, "Invalid size of initial weights matrix" );
7406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvCopy( weights, layer->weights ));
7416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
7436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvRNG rng = cvRNG( 0xFFFFFFFF );
7456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvRandArr( &rng, layer->weights, CV_RAND_UNI, cvRealScalar(-1), cvRealScalar(1) );
7466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
7496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cvGetErrStatus() < 0 && layer )
7516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &layer->exp2ssumWX );
7536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &layer );
7546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (CvCNNLayer*)layer;
7576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
7586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
7606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennML_IMPL CvCNNLayer* cvCreateCNNFullConnectLayer(
7616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_inputs, int n_outputs, float a, float s,
7626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float init_learn_rate, int learn_rate_decrease_type, CvMat* weights )
7636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
7646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNFullConnectLayer* layer = 0;
7656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("cvCreateCNNFullConnectLayer");
7676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
7686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( a <= 0 || s <= 0 || init_learn_rate <= 0)
7706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Incorrect parameters" );
7716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer = (CvCNNFullConnectLayer*)icvCreateCNNLayer( ICV_CNN_FULLCONNECT_LAYER,
7736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sizeof(CvCNNFullConnectLayer), n_inputs, 1, 1, n_outputs, 1, 1,
7746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        init_learn_rate, learn_rate_decrease_type,
7756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        icvCNNFullConnectRelease, icvCNNFullConnectForward, icvCNNFullConnectBackward ));
7766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->a = a;
7786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer->s = s;
7796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer->exp2ssumWX = cvCreateMat( n_outputs, 1, CV_32FC1 ));
7816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( layer->exp2ssumWX );
7826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer->weights = cvCreateMat( n_outputs, n_inputs+1, CV_32FC1 ));
7846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( weights )
7856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !ICV_IS_MAT_OF_TYPE( weights, CV_32FC1 ) )
7876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadSize, "Type of initial weights matrix must be CV_32FC1" );
7886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_ARE_SIZES_EQ( weights, layer->weights ) )
7896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadSize, "Invalid size of initial weights matrix" );
7906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvCopy( weights, layer->weights ));
7916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
7936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvRNG rng = cvRNG( 0xFFFFFFFF );
7956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvRandArr( &rng, layer->weights, CV_RAND_UNI, cvRealScalar(-1), cvRealScalar(1) );
7966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
7996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cvGetErrStatus() < 0 && layer )
8016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &layer->exp2ssumWX );
8036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &layer->weights );
8046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &layer );
8056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (CvCNNLayer*)layer;
8086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
8126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*                           Layer FORWARD functions                                      *
8136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
8146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionForward( CvCNNLayer* _layer,
8156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                      const CvMat* X,
8166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                      CvMat* Y )
8176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNConvolutionForward");
8196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_CONVOLUTION_LAYER(_layer) )
8216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
8226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {__BEGIN__;
8246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvCNNConvolutionLayer* layer = (CvCNNConvolutionLayer*) _layer;
8266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int K = layer->K;
8286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int n_weights_for_Yplane = K*K + 1;
8296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int nXplanes = layer->n_input_planes;
8316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Xheight  = layer->input_height;
8326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Xwidth   = layer->input_width ;
8336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Xsize    = Xwidth*Xheight;
8346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int nYplanes = layer->n_output_planes;
8366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Yheight  = layer->output_height;
8376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Ywidth   = layer->output_width;
8386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Ysize    = Ywidth*Yheight;
8396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int xx, yy, ni, no, kx, ky;
8416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float *Yplane = 0, *Xplane = 0, *w = 0;
8426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    uchar* connect_mask_data = 0;
8436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT( X->rows == nXplanes*Xsize && X->cols == 1 );
8456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT( Y->rows == nYplanes*Ysize && Y->cols == 1 );
8466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvSetZero( Y );
8486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    Yplane = Y->data.fl;
8506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    connect_mask_data = layer->connect_mask->data.ptr;
8516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    w = layer->weights->data.fl;
8526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( no = 0; no < nYplanes; no++, Yplane += Ysize, w += n_weights_for_Yplane )
8536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        Xplane = X->data.fl;
8556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ni = 0; ni < nXplanes; ni++, Xplane += Xsize, connect_mask_data++ )
8566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
8576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( *connect_mask_data )
8586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
8596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                float* Yelem = Yplane;
8606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                // Xheight-K+1 == Yheight && Xwidth-K+1 == Ywidth
8626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( yy = 0; yy < Xheight-K+1; yy++ )
8636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
8646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( xx = 0; xx < Xwidth-K+1; xx++, Yelem++ )
8656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
8666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        float* templ = Xplane+yy*Xwidth+xx;
8676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        float WX = 0;
8686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        for( ky = 0; ky < K; ky++, templ += Xwidth-K )
8696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        {
8706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            for( kx = 0; kx < K; kx++, templ++ )
8716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            {
8726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                WX += *templ*w[ky*K+kx];
8736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            }
8746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        }
8756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        *Yelem += WX + w[K*K];
8766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
8776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
8786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
8796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
8806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }__END__;
8826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
8856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingForward( CvCNNLayer* _layer,
8866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                      const CvMat* X,
8876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                      CvMat* Y )
8886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNSubSamplingForward");
8906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_SUBSAMPLING_LAYER(_layer) )
8926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
8936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {__BEGIN__;
8956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvCNNSubSamplingLayer* layer = (CvCNNSubSamplingLayer*) _layer;
8976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int sub_sampl_scale = layer->sub_samp_scale;
8996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int nplanes = layer->n_input_planes;
9006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Xheight = layer->input_height;
9026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Xwidth  = layer->input_width ;
9036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Xsize   = Xwidth*Xheight;
9046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Yheight = layer->output_height;
9066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Ywidth  = layer->output_width;
9076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Ysize   = Ywidth*Yheight;
9086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int xx, yy, ni, kx, ky;
9106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* sumX_data = 0, *w = 0;
9116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat sumX_sub_col, exp2ssumWX_sub_col;
9126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT(X->rows == nplanes*Xsize && X->cols == 1);
9146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT(layer->exp2ssumWX->cols == 1 && layer->exp2ssumWX->rows == nplanes*Ysize);
9156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // update inner variable layer->exp2ssumWX, which will be used in back-progation
9176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( layer->sumX );
9186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( layer->exp2ssumWX );
9196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( ky = 0; ky < sub_sampl_scale; ky++ )
9216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( kx = 0; kx < sub_sampl_scale; kx++ )
9226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float* Xplane = X->data.fl;
9246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sumX_data = layer->sumX->data.fl;
9256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ni = 0; ni < nplanes; ni++, Xplane += Xsize )
9266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
9276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( yy = 0; yy < Yheight; yy++ )
9286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( xx = 0; xx < Ywidth; xx++, sumX_data++ )
9296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        *sumX_data += Xplane[((yy+ky)*Xwidth+(xx+kx))];
9306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
9316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    w = layer->weights->data.fl;
9346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvGetRows( layer->sumX, &sumX_sub_col, 0, Ysize );
9356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvGetRows( layer->exp2ssumWX, &exp2ssumWX_sub_col, 0, Ysize );
9366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( ni = 0; ni < nplanes; ni++, w += 2 )
9376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvConvertScale( &sumX_sub_col, &exp2ssumWX_sub_col, w[0], w[1] ));
9396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sumX_sub_col.data.fl += Ysize;
9406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        exp2ssumWX_sub_col.data.fl += Ysize;
9416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
9426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvScale( layer->exp2ssumWX, layer->exp2ssumWX, 2.0*layer->s ));
9446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvExp( layer->exp2ssumWX, layer->exp2ssumWX ));
9456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvMinS( layer->exp2ssumWX, FLT_MAX, layer->exp2ssumWX ));
9466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//#ifdef _DEBUG
9476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float* exp2ssumWX_data = layer->exp2ssumWX->data.fl;
9496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ni = 0; ni < layer->exp2ssumWX->rows; ni++, exp2ssumWX_data++ )
9506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( *exp2ssumWX_data == FLT_MAX )
9526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvSetErrStatus( 1 );
9536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
9556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//#endif
9566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compute the output variable Y == ( a - 2a/(layer->exp2ssumWX + 1))
9576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvAddS( layer->exp2ssumWX, cvRealScalar(1), Y ));
9586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvDiv( 0, Y, Y, -2.0*layer->a ));
9596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvAddS( Y, cvRealScalar(layer->a), Y ));
9606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }__END__;
9626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
9656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectForward( CvCNNLayer* _layer, const CvMat* X, CvMat* Y )
9666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNFullConnectForward");
9686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_FULLCONNECT_LAYER(_layer) )
9706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
9716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {__BEGIN__;
9736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvCNNFullConnectLayer* layer = (CvCNNFullConnectLayer*)_layer;
9756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* weights = layer->weights;
9766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat sub_weights, bias;
9776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT(X->cols == 1 && X->rows == layer->n_input_planes);
9796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT(Y->cols == 1 && Y->rows == layer->n_output_planes);
9806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvGetSubRect( weights, &sub_weights,
9826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                          cvRect(0, 0, weights->cols-1, weights->rows )));
9836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvGetCol( weights, &bias, weights->cols-1));
9846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // update inner variable layer->exp2ssumWX, which will be used in Back-Propagation
9866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvGEMM( &sub_weights, X, 2*layer->s, &bias, 2*layer->s, layer->exp2ssumWX ));
9876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvExp( layer->exp2ssumWX, layer->exp2ssumWX ));
9886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvMinS( layer->exp2ssumWX, FLT_MAX, layer->exp2ssumWX ));
9896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//#ifdef _DEBUG
9906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float* exp2ssumWX_data = layer->exp2ssumWX->data.fl;
9926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i;
9936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < layer->exp2ssumWX->rows; i++, exp2ssumWX_data++ )
9946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( *exp2ssumWX_data == FLT_MAX )
9966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvSetErrStatus( 1 );
9976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
9996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//#endif
10006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compute the output variable Y == ( a - 2a/(layer->exp2ssumWX + 1))
10016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvAddS( layer->exp2ssumWX, cvRealScalar(1), Y ));
10026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvDiv( 0, Y, Y, -2.0*layer->a ));
10036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvAddS( Y, cvRealScalar(layer->a), Y ));
10046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }__END__;
10066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
10076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
10096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*                           Layer BACKWARD functions                                     *
10106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
10116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* <dE_dY>, <dE_dX> should be row-vectors.
10136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   Function computes partial derivatives <dE_dX>
10146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   of the loss function with respect to the planes components
10156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   of the previous layer (X).
10166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   It is a basic function for back propagation method.
10176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   Input parameter <dE_dY> is the partial derivative of the
10186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   loss function with respect to the planes components
10196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   of the current layer. */
10206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionBackward(
10216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* _layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX )
10226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
10236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dY_dX = 0;
10246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dY_dW = 0;
10256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dE_dW = 0;
10266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNConvolutionBackward");
10286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_CONVOLUTION_LAYER(_layer) )
10306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
10316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {__BEGIN__;
10336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvCNNConvolutionLayer* layer = (CvCNNConvolutionLayer*) _layer;
10356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int K = layer->K;
10376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int n_X_planes     = layer->n_input_planes;
10396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int X_plane_height = layer->input_height;
10406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int X_plane_width  = layer->input_width;
10416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int X_plane_size   = X_plane_height*X_plane_width;
10426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int n_Y_planes     = layer->n_output_planes;
10446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Y_plane_height = layer->output_height;
10456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Y_plane_width  = layer->output_width;
10466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Y_plane_size   = Y_plane_height*Y_plane_width;
10476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int no, ni, yy, xx, ky, kx;
10496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int X_idx = 0, Y_idx = 0;
10506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float *X_plane = 0, *w = 0;
10526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* weights = layer->weights;
10546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT( t >= 1 );
10566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT( n_Y_planes == weights->rows );
10576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dY_dX = cvCreateMat( n_Y_planes*Y_plane_size, X->rows, CV_32FC1 );
10596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dY_dW = cvCreateMat( dY_dX->rows, weights->cols*weights->rows, CV_32FC1 );
10606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dE_dW = cvCreateMat( 1, dY_dW->cols, CV_32FC1 );
10616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( dY_dX );
10636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( dY_dW );
10646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compute gradient of the loss function with respect to X and W
10666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( no = 0; no < n_Y_planes; no++, Y_idx += Y_plane_size )
10676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        w = weights->data.fl + no*(K*K+1);
10696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        X_idx = 0;
10706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        X_plane = X->data.fl;
10716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ni = 0; ni < n_X_planes; ni++, X_plane += X_plane_size )
10726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( layer->connect_mask->data.ptr[ni*n_Y_planes+no] )
10746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
10756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( yy = 0; yy < X_plane_height - K + 1; yy++ )
10766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
10776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( xx = 0; xx < X_plane_width - K + 1; xx++ )
10786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
10796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        for( ky = 0; ky < K; ky++ )
10806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        {
10816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            for( kx = 0; kx < K; kx++ )
10826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            {
10836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                CV_MAT_ELEM(*dY_dX, float, Y_idx+yy*Y_plane_width+xx,
10846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                    X_idx+(yy+ky)*X_plane_width+(xx+kx)) = w[ky*K+kx];
10856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                // dY_dWi, i=1,...,K*K
10876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                CV_MAT_ELEM(*dY_dW, float, Y_idx+yy*Y_plane_width+xx,
10886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                    no*(K*K+1)+ky*K+kx) +=
10896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                    X_plane[(yy+ky)*X_plane_width+(xx+kx)];
10906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            }
10916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        }
10926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        // dY_dW(K*K+1)==1 because W(K*K+1) is bias
10936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        CV_MAT_ELEM(*dY_dW, float, Y_idx+yy*Y_plane_width+xx,
10946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            no*(K*K+1)+K*K) += 1;
10956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
10966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
10976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
10986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            X_idx += X_plane_size;
10996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
11006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
11016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvMatMul( dE_dY, dY_dW, dE_dW ));
11036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvMatMul( dE_dY, dY_dX, dE_dX ));
11046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // update weights
11066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
11076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat dE_dW_mat;
11086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float eta;
11096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_LOG_INV )
11106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            eta = -layer->init_learn_rate/logf(1+(float)t);
11116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_SQRT_INV )
11126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            eta = -layer->init_learn_rate/sqrtf((float)t);
11136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
11146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            eta = -layer->init_learn_rate/(float)t;
11156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReshape( dE_dW, &dE_dW_mat, 0, weights->rows );
11166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvScaleAdd( &dE_dW_mat, cvRealScalar(eta), weights, weights );
11176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
11186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }__END__;
11206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dY_dX );
11226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dY_dW );
11236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dE_dW );
11246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
11256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
11276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingBackward(
11286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* _layer, int t, const CvMat*, const CvMat* dE_dY, CvMat* dE_dX )
11296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
11306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // derivative of activation function
11316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dY_dX_elems = 0; // elements of matrix dY_dX
11326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dY_dW_elems = 0; // elements of matrix dY_dW
11336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dE_dW = 0;
11346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNSubSamplingBackward");
11366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_SUBSAMPLING_LAYER(_layer) )
11386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
11396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {__BEGIN__;
11416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvCNNSubSamplingLayer* layer = (CvCNNSubSamplingLayer*) _layer;
11436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Xwidth  = layer->input_width;
11456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Ywidth  = layer->output_width;
11466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Yheight = layer->output_height;
11476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int Ysize   = Ywidth * Yheight;
11486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int scale   = layer->sub_samp_scale;
11496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int k_max   = layer->n_output_planes * Yheight;
11506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int k, i, j, m;
11526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* dY_dX_current_elem = 0, *dE_dX_start = 0, *dE_dW_data = 0, *w = 0;
11536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat dy_dw0, dy_dw1;
11546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat activ_func_der, sumX_row;
11556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat dE_dY_sub_row, dY_dX_sub_col, dy_dw0_sub_row, dy_dw1_sub_row;
11566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(dY_dX_elems = cvCreateMat( layer->sumX->rows, 1, CV_32FC1 ));
11586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(dY_dW_elems = cvCreateMat( 2, layer->sumX->rows, CV_32FC1 ));
11596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(dE_dW = cvCreateMat( 1, 2*layer->n_output_planes, CV_32FC1 ));
11606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compute derivative of activ.func.
11626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // ==<dY_dX_elems> = 4as*(layer->exp2ssumWX)/(layer->exp2ssumWX + 1)^2
11636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvAddS( layer->exp2ssumWX, cvRealScalar(1), dY_dX_elems ));
11646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvPow( dY_dX_elems, dY_dX_elems, -2.0 ));
11656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvMul( dY_dX_elems, layer->exp2ssumWX, dY_dX_elems, 4.0*layer->a*layer->s ));
11666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compute <dE_dW>
11686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // a) compute <dY_dW_elems>
11696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReshape( dY_dX_elems, &activ_func_der, 0, 1 );
11706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvGetRow( dY_dW_elems, &dy_dw0, 0 );
11716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvGetRow( dY_dW_elems, &dy_dw1, 1 );
11726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvCopy( &activ_func_der, &dy_dw0 ));
11736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvCopy( &activ_func_der, &dy_dw1 ));
11746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReshape( layer->sumX, &sumX_row, 0, 1 );
11766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvMul( &dy_dw0, &sumX_row, &dy_dw0 );
11776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // b) compute <dE_dW> = <dE_dY>*<dY_dW_elems>
11796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvGetCols( dE_dY, &dE_dY_sub_row, 0, Ysize );
11806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvGetCols( &dy_dw0, &dy_dw0_sub_row, 0, Ysize );
11816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvGetCols( &dy_dw1, &dy_dw1_sub_row, 0, Ysize );
11826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dE_dW_data = dE_dW->data.fl;
11836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < layer->n_output_planes; i++ )
11846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
11856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        *dE_dW_data++ = (float)cvDotProduct( &dE_dY_sub_row, &dy_dw0_sub_row );
11866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        *dE_dW_data++ = (float)cvDotProduct( &dE_dY_sub_row, &dy_dw1_sub_row );
11876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dE_dY_sub_row.data.fl += Ysize;
11896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dy_dw0_sub_row.data.fl += Ysize;
11906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dy_dw1_sub_row.data.fl += Ysize;
11916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
11926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compute <dY_dX> = layer->weights*<dY_dX>
11946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    w = layer->weights->data.fl;
11956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvGetRows( dY_dX_elems, &dY_dX_sub_col, 0, Ysize );
11966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < layer->n_input_planes; i++, w++, dY_dX_sub_col.data.fl += Ysize )
11976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvConvertScale( &dY_dX_sub_col, &dY_dX_sub_col, (float)*w ));
11986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compute <dE_dX>
12006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvReshape( dY_dX_elems, dY_dX_elems, 0, 1 ));
12016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvMul( dY_dX_elems, dE_dY, dY_dX_elems ));
12026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dY_dX_current_elem = dY_dX_elems->data.fl;
12046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dE_dX_start = dE_dX->data.fl;
12056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k < k_max; k++ )
12066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
12076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < Ywidth; i++, dY_dX_current_elem++ )
12086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
12096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float* dE_dX_current_elem = dE_dX_start;
12106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < scale; j++, dE_dX_current_elem += Xwidth - scale )
12116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
12126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( m = 0; m < scale; m++, dE_dX_current_elem++ )
12136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    *dE_dX_current_elem = *dY_dX_current_elem;
12146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
12156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dE_dX_start += scale;
12166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
12176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dE_dX_start += Xwidth * (scale - 1);
12186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
12196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // update weights
12216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
12226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat dE_dW_mat, *weights = layer->weights;
12236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float eta;
12246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_LOG_INV )
12256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            eta = -layer->init_learn_rate/logf(1+(float)t);
12266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_SQRT_INV )
12276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            eta = -layer->init_learn_rate/sqrtf((float)t);
12286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
12296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            eta = -layer->init_learn_rate/(float)t;
12306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReshape( dE_dW, &dE_dW_mat, 0, weights->rows );
12316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvScaleAdd( &dE_dW_mat, cvRealScalar(eta), weights, weights );
12326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
12336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }__END__;
12356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dY_dX_elems );
12376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dY_dW_elems );
12386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dE_dW );
12396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
12426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* <dE_dY>, <dE_dX> should be row-vectors.
12436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   Function computes partial derivatives <dE_dX>, <dE_dW>
12446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   of the loss function with respect to the planes components
12456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   of the previous layer (X) and the weights of the current layer (W)
12466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   and updates weights od the current layer by using <dE_dW>.
12476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   It is a basic function for back propagation method.
12486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   Input parameter <dE_dY> is the partial derivative of the
12496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   loss function with respect to the planes components
12506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   of the current layer. */
12516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectBackward( CvCNNLayer* _layer,
12526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                    int t,
12536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                    const CvMat* X,
12546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                    const CvMat* dE_dY,
12556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                    CvMat* dE_dX )
12566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dE_dY_activ_func_der = 0;
12586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* dE_dW = 0;
12596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "icvCNNFullConnectBackward" );
12616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_FULLCONNECT_LAYER(_layer) )
12636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
12646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {__BEGIN__;
12666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvCNNFullConnectLayer* layer = (CvCNNFullConnectLayer*)_layer;
12686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int n_outputs = layer->n_output_planes;
12696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int n_inputs  = layer->n_input_planes;
12706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i;
12726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* dE_dY_activ_func_der_data;
12736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* weights = layer->weights;
12746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat sub_weights, Xtemplate, Xrow, exp2ssumWXrow;
12756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT(X->cols == 1 && X->rows == n_inputs);
12776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT(dE_dY->rows == 1 && dE_dY->cols == n_outputs );
12786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT(dE_dX->rows == 1 && dE_dX->cols == n_inputs );
12796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // we violate the convetion about vector's orientation because
12816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // here is more convenient to make this parameter a row-vector
12826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(dE_dY_activ_func_der = cvCreateMat( 1, n_outputs, CV_32FC1 ));
12836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(dE_dW = cvCreateMat( 1, weights->rows*weights->cols, CV_32FC1 ));
12846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // 1) compute gradients dE_dX and dE_dW
12866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // activ_func_der == 4as*(layer->exp2ssumWX)/(layer->exp2ssumWX + 1)^2
12876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvReshape( layer->exp2ssumWX, &exp2ssumWXrow, 0, layer->exp2ssumWX->cols ));
12886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvAddS( &exp2ssumWXrow, cvRealScalar(1), dE_dY_activ_func_der ));
12896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvPow( dE_dY_activ_func_der, dE_dY_activ_func_der, -2.0 ));
12906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvMul( dE_dY_activ_func_der, &exp2ssumWXrow, dE_dY_activ_func_der,
12916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                   4.0*layer->a*layer->s ));
12926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvMul( dE_dY, dE_dY_activ_func_der, dE_dY_activ_func_der ));
12936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // sub_weights = d(W*(X|1))/dX
12956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvGetSubRect( weights, &sub_weights,
12966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvRect(0, 0, weights->cols-1, weights->rows) ));
12976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvMatMul( dE_dY_activ_func_der, &sub_weights, dE_dX ));
12986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReshape( X, &Xrow, 0, 1 );
13006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dE_dY_activ_func_der_data = dE_dY_activ_func_der->data.fl;
13016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    Xtemplate = cvMat( 1, n_inputs, CV_32FC1, dE_dW->data.fl );
13026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n_outputs; i++, Xtemplate.data.fl += n_inputs + 1 )
13036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvConvertScale( &Xrow, &Xtemplate, *dE_dY_activ_func_der_data ));
13056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        Xtemplate.data.fl[n_inputs] = *dE_dY_activ_func_der_data++;
13066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // 2) update weights
13096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat dE_dW_mat;
13116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float eta;
13126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_LOG_INV )
13136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            eta = -layer->init_learn_rate/logf(1+(float)t);
13146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( layer->learn_rate_decrease_type == CV_CNN_LEARN_RATE_DECREASE_SQRT_INV )
13156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            eta = -layer->init_learn_rate/sqrtf((float)t);
13166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
13176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            eta = -layer->init_learn_rate/(float)t;
13186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReshape( dE_dW, &dE_dW_mat, 0, n_outputs );
13196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvScaleAdd( &dE_dW_mat, cvRealScalar(eta), weights, weights );
13206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }__END__;
13236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dE_dY_activ_func_der );
13256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &dE_dW );
13266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
13276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
13296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*                           Layer RELEASE functions                                      *
13306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
13316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNConvolutionRelease( CvCNNLayer** p_layer )
13326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
13336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNConvolutionRelease");
13346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
13356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNConvolutionLayer* layer = 0;
13376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !p_layer )
13396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsNullPtr, "Null double pointer" );
13406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer = *(CvCNNConvolutionLayer**)p_layer;
13426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !layer )
13446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return;
13456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_CONVOLUTION_LAYER(layer) )
13466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
13476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &layer->weights );
13496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &layer->connect_mask );
13506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( p_layer );
13516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
13536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
13546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
13566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNSubSamplingRelease( CvCNNLayer** p_layer )
13576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
13586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNSubSamplingRelease");
13596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
13606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNSubSamplingLayer* layer = 0;
13626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !p_layer )
13646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsNullPtr, "Null double pointer" );
13656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer = *(CvCNNSubSamplingLayer**)p_layer;
13676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !layer )
13696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return;
13706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_SUBSAMPLING_LAYER(layer) )
13716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
13726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &layer->exp2ssumWX );
13746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &layer->weights );
13756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( p_layer );
13766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
13786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
13796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
13816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvCNNFullConnectRelease( CvCNNLayer** p_layer )
13826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
13836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvCNNFullConnectRelease");
13846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
13856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNFullConnectLayer* layer = 0;
13876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !p_layer )
13896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsNullPtr, "Null double pointer" );
13906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer = *(CvCNNFullConnectLayer**)p_layer;
13926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !layer )
13946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return;
13956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_FULLCONNECT_LAYER(layer) )
13966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
13976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &layer->exp2ssumWX );
13996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &layer->weights );
14006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( p_layer );
14016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
14036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
14046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
14066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*                              Read/Write CNN classifier                                 *
14076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
14086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic int icvIsCNNModel( const void* ptr )
14096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
14106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return CV_IS_CNN(ptr);
14116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
14126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
14146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvReleaseCNNModel( void** ptr )
14156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
14166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvReleaseCNNModel");
14176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
14186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ptr )
14206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsNullPtr, "NULL double pointer" );
14216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_ASSERT(CV_IS_CNN(*ptr));
14226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    icvCNNModelRelease( (CvStatModel**)ptr );
14246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
14266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
14276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
14296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CvCNNLayer* icvReadCNNLayer( CvFileStorage* fs, CvFileNode* node )
14306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
14316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* layer = 0;
14326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* weights    = 0;
14336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* connect_mask = 0;
14346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvReadCNNLayer");
14366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
14376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_input_planes, input_height, input_width;
14396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_output_planes, output_height, output_width;
14406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int learn_type, layer_type;
14416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float init_learn_rate;
14426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(n_input_planes  = cvReadIntByName( fs, node, "n_input_planes",  -1 ));
14446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(input_height    = cvReadIntByName( fs, node, "input_height",    -1 ));
14456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(input_width     = cvReadIntByName( fs, node, "input_width",     -1 ));
14466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(n_output_planes = cvReadIntByName( fs, node, "n_output_planes", -1 ));
14476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(output_height   = cvReadIntByName( fs, node, "output_height",   -1 ));
14486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(output_width    = cvReadIntByName( fs, node, "output_width",    -1 ));
14496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer_type      = cvReadIntByName( fs, node, "layer_type",      -1 ));
14506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(init_learn_rate = (float)cvReadRealByName( fs, node, "init_learn_rate", -1 ));
14526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(learn_type = cvReadIntByName( fs, node, "learn_rate_decrease_type", -1 ));
14536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(weights    = (CvMat*)cvReadByName( fs, node, "weights" ));
14546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( n_input_planes < 0  || input_height < 0  || input_width < 0 ||
14566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        n_output_planes < 0 || output_height < 0 || output_width < 0 ||
14576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        init_learn_rate < 0 || learn_type < 0 || layer_type < 0 || !weights )
14586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "" );
14596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( layer_type == ICV_CNN_CONVOLUTION_LAYER )
14616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
14626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int K = input_height - output_height + 1;
14636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( K <= 0 || K != input_width - output_width + 1 )
14646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "Invalid <K>" );
14656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(connect_mask = (CvMat*)cvReadByName( fs, node, "connect_mask" ));
14676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !connect_mask )
14686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsParseError, "Missing <connect mask>" );
14696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(layer = cvCreateCNNConvolutionLayer(
14716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            n_input_planes, input_height, input_width, n_output_planes, K,
14726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            init_learn_rate, learn_type, connect_mask, weights ));
14736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
14746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( layer_type == ICV_CNN_SUBSAMPLING_LAYER )
14756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
14766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float a, s;
14776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int sub_samp_scale = input_height/output_height;
14786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( sub_samp_scale <= 0 || sub_samp_scale != input_width/output_width )
14806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "Invalid <sub_samp_scale>" );
14816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(a = (float)cvReadRealByName( fs, node, "a", -1 ));
14836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(s = (float)cvReadRealByName( fs, node, "s", -1 ));
14846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( a  < 0 || s  < 0 )
14856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsParseError, "Missing <a> or <s>" );
14866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(layer = cvCreateCNNSubSamplingLayer(
14886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            n_input_planes, input_height, input_width, sub_samp_scale,
14896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            a, s, init_learn_rate, learn_type, weights ));
14906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
14916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( layer_type == ICV_CNN_FULLCONNECT_LAYER )
14926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
14936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float a, s;
14946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(a = (float)cvReadRealByName( fs, node, "a", -1 ));
14956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(s = (float)cvReadRealByName( fs, node, "s", -1 ));
14966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( a  < 0 || s  < 0 )
14976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsParseError, "" );
14986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( input_height != 1  || input_width != 1 ||
14996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            output_height != 1 || output_width != 1 )
15006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "" );
15016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(layer = cvCreateCNNFullConnectLayer( n_input_planes, n_output_planes,
15036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            a, s, init_learn_rate, learn_type, weights ));
15046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
15066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid <layer_type>" );
15076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
15096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cvGetErrStatus() < 0 && layer )
15116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        layer->release( &layer );
15126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &weights );
15146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &connect_mask );
15156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return layer;
15176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
15186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
15206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void icvWriteCNNLayer( CvFileStorage* fs, CvCNNLayer* layer )
15216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
15226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME ("icvWriteCNNLayer");
15236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
15246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !ICV_IS_CNN_LAYER(layer) )
15266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
15276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvStartWriteStruct( fs, NULL, CV_NODE_MAP, "opencv-ml-cnn-layer" ));
15296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWriteInt( fs, "n_input_planes",  layer->n_input_planes ));
15316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWriteInt( fs, "input_height",    layer->input_height ));
15326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWriteInt( fs, "input_width",     layer->input_width ));
15336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWriteInt( fs, "n_output_planes", layer->n_output_planes ));
15346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWriteInt( fs, "output_height",   layer->output_height ));
15356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWriteInt( fs, "output_width",    layer->output_width ));
15366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWriteInt( fs, "learn_rate_decrease_type", layer->learn_rate_decrease_type));
15376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWriteReal( fs, "init_learn_rate", layer->init_learn_rate ));
15386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWrite( fs, "weights", layer->weights ));
15396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( ICV_IS_CNN_CONVOLUTION_LAYER( layer ))
15416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvCNNConvolutionLayer* l = (CvCNNConvolutionLayer*)layer;
15436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvWriteInt( fs, "layer_type", ICV_CNN_CONVOLUTION_LAYER ));
15446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvWrite( fs, "connect_mask", l->connect_mask ));
15456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( ICV_IS_CNN_SUBSAMPLING_LAYER( layer ) )
15476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvCNNSubSamplingLayer* l = (CvCNNSubSamplingLayer*)layer;
15496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvWriteInt( fs, "layer_type", ICV_CNN_SUBSAMPLING_LAYER ));
15506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvWriteReal( fs, "a", l->a ));
15516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvWriteReal( fs, "s", l->s ));
15526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( ICV_IS_CNN_FULLCONNECT_LAYER( layer ) )
15546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvCNNFullConnectLayer* l = (CvCNNFullConnectLayer*)layer;
15566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvWriteInt( fs, "layer_type", ICV_CNN_FULLCONNECT_LAYER ));
15576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvWriteReal( fs, "a", l->a ));
15586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvWriteReal( fs, "s", l->s ));
15596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
15616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid layer" );
15626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvEndWriteStruct( fs )); //"opencv-ml-cnn-layer"
15646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
15666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
15676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
15696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void* icvReadCNNModel( CvFileStorage* fs, CvFileNode* root_node )
15706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
15716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNStatModel* cnn = 0;
15726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* layer = 0;
15736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("icvReadCNNModel");
15756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
15766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvFileNode* node;
15786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvSeq* seq;
15796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvSeqReader reader;
15806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i;
15816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cnn = (CvCNNStatModel*)cvCreateStatModel(
15836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_STAT_MODEL_MAGIC_VAL|CV_CNN_MAGIC_VAL, sizeof(CvCNNStatModel),
15846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        icvCNNModelRelease, icvCNNModelPredict, icvCNNModelUpdate ));
15856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cnn->etalons = (CvMat*)cvReadByName( fs, root_node, "etalons" ));
15876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cnn->cls_labels = (CvMat*)cvReadByName( fs, root_node, "cls_labels" ));
15886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !cnn->etalons || !cnn->cls_labels )
15906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "No <etalons> or <cls_labels> in CNN model" );
15916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( node = cvGetFileNodeByName( fs, root_node, "network" ));
15936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    seq = node->data.seq;
15946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_NODE_IS_SEQ(node->tag) )
15956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "" );
15966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
15986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(layer = icvReadCNNLayer( fs, (CvFileNode*)reader.ptr ));
15996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cnn->network = cvCreateCNNetwork( layer ));
16006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 1; i < seq->total; i++ )
16026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
16036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
16046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(layer = icvReadCNNLayer( fs, (CvFileNode*)reader.ptr ));
16056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cnn->network->add_layer( cnn->network, layer ));
16066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
16076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
16096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cvGetErrStatus() < 0 )
16116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
16126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( cnn ) cnn->release( (CvStatModel**)&cnn );
16136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( layer ) layer->release( &layer );
16146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
16156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (void*)cnn;
16166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
16176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
16196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic void
16206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennicvWriteCNNModel( CvFileStorage* fs, const char* name,
16216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                  const void* struct_ptr, CvAttrList )
16226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
16246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME ("icvWriteCNNModel");
16256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
16266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNStatModel* cnn = (CvCNNStatModel*)struct_ptr;
16286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n_layers, i;
16296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvCNNLayer* layer;
16306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_CNN(cnn) )
16326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid pointer" );
16336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    n_layers = cnn->network->n_layers;
16356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_CNN ));
16376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWrite( fs, "etalons", cnn->etalons ));
16396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(cvWrite( fs, "cls_labels", cnn->cls_labels ));
16406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvStartWriteStruct( fs, "network", CV_NODE_SEQ ));
16426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    layer = cnn->network->layers;
16446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n_layers && layer; i++, layer = layer->next_layer )
16456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(icvWriteCNNLayer( fs, layer ));
16466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( i < n_layers || layer )
16476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid network" );
16486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvEndWriteStruct( fs )); //"network"
16506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvEndWriteStruct( fs )); //"opencv-ml-cnn"
16516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
16536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
16546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic int icvRegisterCNNStatModelType()
16566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
16576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvTypeInfo info;
16586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    info.header_size = sizeof( info );
16606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    info.is_instance = icvIsCNNModel;
16616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    info.release = icvReleaseCNNModel;
16626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    info.read = icvReadCNNModel;
16636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    info.write = icvWriteCNNModel;
16646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    info.clone = NULL;
16656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    info.type_name = CV_TYPE_NAME_ML_CNN;
16666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvRegisterType( &info );
16676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return 1;
16696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} // End of icvRegisterCNNStatModelType
16706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic int cnn = icvRegisterCNNStatModelType();
16726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#endif
16746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// End of file
1676