1/*M///////////////////////////////////////////////////////////////////////////////////////
2//
3//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4//
5//  By downloading, copying, installing or using the software you agree to this license.
6//  If you do not agree to this license, do not download, install,
7//  copy or use the software.
8//
9//
10//                        Intel License Agreement
11//
12// Copyright (C) 2000, Intel Corporation, all rights reserved.
13// Third party copyrights are property of their respective owners.
14//
15// Redistribution and use in source and binary forms, with or without modification,
16// are permitted provided that the following conditions are met:
17//
18//   * Redistribution's of source code must retain the above copyright notice,
19//     this list of conditions and the following disclaimer.
20//
21//   * Redistribution's in binary form must reproduce the above copyright notice,
22//     this list of conditions and the following disclaimer in the documentation
23//     and/or other materials provided with the distribution.
24//
25//   * The name of Intel Corporation may not be used to endorse or promote products
26//     derived from this software without specific prior written permission.
27//
28// This software is provided by the copyright holders and contributors "as is" and
29// any express or implied warranties, including, but not limited to, the implied
30// warranties of merchantability and fitness for a particular purpose are disclaimed.
31// In no event shall the Intel Corporation or contributors be liable for any direct,
32// indirect, incidental, special, exemplary, or consequential damages
33// (including, but not limited to, procurement of substitute goods or services;
34// loss of use, data, or profits; or business interruption) however caused
35// and on any theory of liability, whether in contract, strict liability,
36// or tort (including negligence or otherwise) arising in any way out of
37// the use of this software, even if advised of the possibility of such damage.
38//
39//M*/
40
41#ifndef __ML_H__
42#define __ML_H__
43
// Disable the deprecation warning (C4996) that appears in Visual Studio 8.0.
// The defined() test keeps compilers that do not define _MSC_VER from
// evaluating an undefined macro in the #if (which -Wundef reports).
#if defined _MSC_VER && _MSC_VER >= 1400
#pragma warning( disable : 4996 )
#endif
48
49#ifndef SKIP_INCLUDES
50
51  #include "cxcore.h"
52  #include <limits.h>
53
54  #if defined WIN32 || defined WIN64
55    #include <windows.h>
56  #endif
57
58#else // SKIP_INCLUDES
59
60  #if defined WIN32 || defined WIN64
61    #define CV_CDECL __cdecl
62    #define CV_STDCALL __stdcall
63  #else
64    #define CV_CDECL
65    #define CV_STDCALL
66  #endif
67
68  #ifndef CV_EXTERN_C
69    #ifdef __cplusplus
70      #define CV_EXTERN_C extern "C"
71      #define CV_DEFAULT(val) = val
72    #else
73      #define CV_EXTERN_C
74      #define CV_DEFAULT(val)
75    #endif
76  #endif
77
78  #ifndef CV_EXTERN_C_FUNCPTR
79    #ifdef __cplusplus
80      #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
81    #else
82      #define CV_EXTERN_C_FUNCPTR(x) typedef x
83    #endif
84  #endif
85
86  #ifndef CV_INLINE
87    #if defined __cplusplus
88      #define CV_INLINE inline
89    #elif (defined WIN32 || defined WIN64) && !defined __GNUC__
90      #define CV_INLINE __inline
91    #else
92      #define CV_INLINE static
93    #endif
94  #endif /* CV_INLINE */
95
96  #if (defined WIN32 || defined WIN64) && defined CVAPI_EXPORTS
97    #define CV_EXPORTS __declspec(dllexport)
98  #else
99    #define CV_EXPORTS
100  #endif
101
102  #ifndef CVAPI
103    #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
104  #endif
105
106#endif // SKIP_INCLUDES
107
108
109#ifdef __cplusplus
110
// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definition in this header
113#undef check
114
115/****************************************************************************************\
116*                               Main struct definitions                                  *
117\****************************************************************************************/
118
/* Natural logarithm of 2*PI */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* Layout flags for the <trainData> matrix:
     CV_COL_SAMPLE - columns of the matrix are training samples
     CV_ROW_SAMPLE - rows of the matrix are training samples */
#define CV_COL_SAMPLE 0
#define CV_ROW_SAMPLE 1

/* Non-zero when the CV_ROW_SAMPLE bit is set in <flags> */
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
129
130struct CvVectors
131{
132    int type;
133    int dims, count;
134    CvVectors* next;
135    union
136    {
137        uchar** ptr;
138        float** fl;
139        double** db;
140    } data;
141};
142
143#if 0
/* A structure representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by the cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1. */
147typedef struct CvParamLattice
148{
149    double min_val;
150    double max_val;
151    double step;
152}
153CvParamLattice;
154
155CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
156                                         double log_step )
157{
158    CvParamLattice pl;
159    pl.min_val = MIN( min_val, max_val );
160    pl.max_val = MAX( min_val, max_val );
161    pl.step = MAX( log_step, 1. );
162    return pl;
163}
164
165CV_INLINE CvParamLattice cvDefaultParamLattice( void )
166{
167    CvParamLattice pl = {0,0,0};
168    return pl;
169}
170#endif
171
/* Variable type: "numerical" and "ordered" are two names for the same value */
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

/* Type-name strings of the ML models */
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
186
187class CV_EXPORTS CvStatModel
188{
189public:
190    CvStatModel();
191    virtual ~CvStatModel();
192
193    virtual void clear();
194
195    virtual void save( const char* filename, const char* name=0 );
196    virtual void load( const char* filename, const char* name=0 );
197
198    virtual void write( CvFileStorage* storage, const char* name );
199    virtual void read( CvFileStorage* storage, CvFileNode* node );
200
201protected:
202    const char* default_model_name;
203};
204
205
206/****************************************************************************************\
207*                                 Normal Bayes Classifier                                *
208\****************************************************************************************/
209
210/* The structure, representing the grid range of statmodel parameters.
211   It is used for optimizing statmodel accuracy by varying model parameters,
212   the accuracy estimate being computed by cross-validation.
213   The grid is logarithmic, so <step> must be greater then 1. */
214struct CV_EXPORTS CvParamGrid
215{
216    // SVM params type
217    enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
218
219    CvParamGrid()
220    {
221        min_val = max_val = step = 0;
222    }
223
224    CvParamGrid( double _min_val, double _max_val, double log_step )
225    {
226        min_val = _min_val;
227        max_val = _max_val;
228        step = log_step;
229    }
230    //CvParamGrid( int param_id );
231    bool check() const;
232
233    double min_val;
234    double max_val;
235    double step;
236};
237
238class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
239{
240public:
241    CvNormalBayesClassifier();
242    virtual ~CvNormalBayesClassifier();
243
244    CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
245        const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );
246
247    virtual bool train( const CvMat* _train_data, const CvMat* _responses,
248        const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );
249
250    virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;
251    virtual void clear();
252
253    virtual void write( CvFileStorage* storage, const char* name );
254    virtual void read( CvFileStorage* storage, CvFileNode* node );
255
256protected:
257    int     var_count, var_all;
258    CvMat*  var_idx;
259    CvMat*  cls_labels;
260    CvMat** count;
261    CvMat** sum;
262    CvMat** productsum;
263    CvMat** avg;
264    CvMat** inv_eigen_values;
265    CvMat** cov_rotate_mats;
266    CvMat*  c;
267};
268
269
270/****************************************************************************************\
271*                          K-Nearest Neighbour Classifier                                *
272\****************************************************************************************/
273
274// k Nearest Neighbors
275class CV_EXPORTS CvKNearest : public CvStatModel
276{
277public:
278
279    CvKNearest();
280    virtual ~CvKNearest();
281
282    CvKNearest( const CvMat* _train_data, const CvMat* _responses,
283                const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
284
285    virtual bool train( const CvMat* _train_data, const CvMat* _responses,
286                        const CvMat* _sample_idx=0, bool is_regression=false,
287                        int _max_k=32, bool _update_base=false );
288
289    virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
290        const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
291
292    virtual void clear();
293    int get_max_k() const;
294    int get_var_count() const;
295    int get_sample_count() const;
296    bool is_regression() const;
297
298protected:
299
300    virtual float write_results( int k, int k1, int start, int end,
301        const float* neighbor_responses, const float* dist, CvMat* _results,
302        CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
303
304    virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
305        float* neighbor_responses, const float** neighbors, float* dist ) const;
306
307
308    int max_k, var_count;
309    int total;
310    bool regression;
311    CvVectors* samples;
312};
313
314/****************************************************************************************\
315*                                   Support Vector Machines                              *
316\****************************************************************************************/
317
318// SVM training parameters
319struct CV_EXPORTS CvSVMParams
320{
321    CvSVMParams();
322    CvSVMParams( int _svm_type, int _kernel_type,
323                 double _degree, double _gamma, double _coef0,
324                 double _C, double _nu, double _p,
325                 CvMat* _class_weights, CvTermCriteria _term_crit );
326
327    int         svm_type;
328    int         kernel_type;
329    double      degree; // for poly
330    double      gamma;  // for poly/rbf/sigmoid
331    double      coef0;  // for poly/sigmoid
332
333    double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
334    double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
335    double      p; // for CV_SVM_EPS_SVR
336    CvMat*      class_weights; // for CV_SVM_C_SVC
337    CvTermCriteria term_crit; // termination criteria
338};
339
340
341struct CV_EXPORTS CvSVMKernel
342{
343    typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
344                                       const float* another, float* results );
345    CvSVMKernel();
346    CvSVMKernel( const CvSVMParams* _params, Calc _calc_func );
347    virtual bool create( const CvSVMParams* _params, Calc _calc_func );
348    virtual ~CvSVMKernel();
349
350    virtual void clear();
351    virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
352
353    const CvSVMParams* params;
354    Calc calc_func;
355
356    virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
357                                    const float* another, float* results,
358                                    double alpha, double beta );
359
360    virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
361                              const float* another, float* results );
362    virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
363                           const float* another, float* results );
364    virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
365                            const float* another, float* results );
366    virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
367                               const float* another, float* results );
368};
369
370
// List node with links in both directions and a pointer to float data.
// NOTE(review): presumably a cached kernel row (CvSVMSolver keeps an <lru_list>
// of this type) - confirm.
struct CvSVMKernelRow
{
    CvSVMKernelRow *prev, *next;
    float* data;
};
377
378
// Scalar results of an SVM solver run; the solve_* methods of CvSVMSolver
// take it by reference.
struct CvSVMSolutionInfo
{
    double obj;
    double rho;
    double upper_bound_p;
    double upper_bound_n;
    double r;               // for Solver_NU
};
387
388class CV_EXPORTS CvSVMSolver
389{
390public:
391    typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
392    typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
393    typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
394
395    CvSVMSolver();
396
397    CvSVMSolver( int count, int var_count, const float** samples, schar* y,
398                 int alpha_count, double* alpha, double Cp, double Cn,
399                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
400                 SelectWorkingSet select_working_set, CalcRho calc_rho );
401    virtual bool create( int count, int var_count, const float** samples, schar* y,
402                 int alpha_count, double* alpha, double Cp, double Cn,
403                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
404                 SelectWorkingSet select_working_set, CalcRho calc_rho );
405    virtual ~CvSVMSolver();
406
407    virtual void clear();
408    virtual bool solve_generic( CvSVMSolutionInfo& si );
409
410    virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
411                              double Cp, double Cn, CvMemStorage* storage,
412                              CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
413    virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
414                               CvMemStorage* storage, CvSVMKernel* kernel,
415                               double* alpha, CvSVMSolutionInfo& si );
416    virtual bool solve_one_class( int count, int var_count, const float** samples,
417                                  CvMemStorage* storage, CvSVMKernel* kernel,
418                                  double* alpha, CvSVMSolutionInfo& si );
419
420    virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
421                                CvMemStorage* storage, CvSVMKernel* kernel,
422                                double* alpha, CvSVMSolutionInfo& si );
423
424    virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
425                               CvMemStorage* storage, CvSVMKernel* kernel,
426                               double* alpha, CvSVMSolutionInfo& si );
427
428    virtual float* get_row_base( int i, bool* _existed );
429    virtual float* get_row( int i, float* dst );
430
431    int sample_count;
432    int var_count;
433    int cache_size;
434    int cache_line_size;
435    const float** samples;
436    const CvSVMParams* params;
437    CvMemStorage* storage;
438    CvSVMKernelRow lru_list;
439    CvSVMKernelRow* rows;
440
441    int alpha_count;
442
443    double* G;
444    double* alpha;
445
446    // -1 - lower bound, 0 - free, 1 - upper bound
447    schar* alpha_status;
448
449    schar* y;
450    double* b;
451    float* buf[2];
452    double eps;
453    int max_iter;
454    double C[2];  // C[0] == Cn, C[1] == Cp
455    CvSVMKernel* kernel;
456
457    SelectWorkingSet select_working_set_func;
458    CalcRho calc_rho_func;
459    GetRow get_row_func;
460
461    virtual bool select_working_set( int& i, int& j );
462    virtual bool select_working_set_nu_svm( int& i, int& j );
463    virtual void calc_rho( double& rho, double& r );
464    virtual void calc_rho_nu_svm( double& rho, double& r );
465
466    virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
467    virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
468    virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
469};
470
471
// One decision function of a trained SVM (CvSVM keeps a pointer to these).
// NOTE(review): <alpha> and <sv_index> presumably hold <sv_count> entries each - confirm.
struct CvSVMDecisionFunc
{
    double  rho;
    int     sv_count;
    double* alpha;
    int*    sv_index;
};
479
480
481// SVM model
482class CV_EXPORTS CvSVM : public CvStatModel
483{
484public:
485    // SVM type
486    enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
487
488    // SVM kernel type
489    enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
490
491    // SVM params type
492    enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
493
494    CvSVM();
495    virtual ~CvSVM();
496
497    CvSVM( const CvMat* _train_data, const CvMat* _responses,
498           const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
499           CvSVMParams _params=CvSVMParams() );
500
501    virtual bool train( const CvMat* _train_data, const CvMat* _responses,
502                        const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
503                        CvSVMParams _params=CvSVMParams() );
504    virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses,
505        const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params,
506        int k_fold = 10,
507        CvParamGrid C_grid      = get_default_grid(CvSVM::C),
508        CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
509        CvParamGrid p_grid      = get_default_grid(CvSVM::P),
510        CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
511        CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
512        CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
513
514    virtual float predict( const CvMat* _sample ) const;
515
516    virtual int get_support_vector_count() const;
517    virtual const float* get_support_vector(int i) const;
518    virtual CvSVMParams get_params() const { return params; };
519    virtual void clear();
520
521    static CvParamGrid get_default_grid( int param_id );
522
523    virtual void write( CvFileStorage* storage, const char* name );
524    virtual void read( CvFileStorage* storage, CvFileNode* node );
525    int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
526
527protected:
528
529    virtual bool set_params( const CvSVMParams& _params );
530    virtual bool train1( int sample_count, int var_count, const float** samples,
531                    const void* _responses, double Cp, double Cn,
532                    CvMemStorage* _storage, double* alpha, double& rho );
533    virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
534                    const CvMat* _responses, CvMemStorage* _storage, double* alpha );
535    virtual void create_kernel();
536    virtual void create_solver();
537
538    virtual void write_params( CvFileStorage* fs );
539    virtual void read_params( CvFileStorage* fs, CvFileNode* node );
540
541    CvSVMParams params;
542    CvMat* class_labels;
543    int var_all;
544    float** sv;
545    int sv_total;
546    CvMat* var_idx;
547    CvMat* class_weights;
548    CvSVMDecisionFunc* decision_func;
549    CvMemStorage* storage;
550
551    CvSVMSolver* solver;
552    CvSVMKernel* kernel;
553};
554
555/****************************************************************************************\
556*                              Expectation - Maximization                                *
557\****************************************************************************************/
558
559struct CV_EXPORTS CvEMParams
560{
561    CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
562        start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
563    {
564        term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
565    }
566
567    CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
568                int _start_step=0/*CvEM::START_AUTO_STEP*/,
569                CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
570                const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
571                nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
572                probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
573    {}
574
575    int nclusters;
576    int cov_mat_type;
577    int start_step;
578    const CvMat* probs;
579    const CvMat* weights;
580    const CvMat* means;
581    const CvMat** covs;
582    CvTermCriteria term_crit;
583};
584
585
586class CV_EXPORTS CvEM : public CvStatModel
587{
588public:
589    // Type of covariation matrices
590    enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
591
592    // The initial step
593    enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
594
595    CvEM();
596    CvEM( const CvMat* samples, const CvMat* sample_idx=0,
597          CvEMParams params=CvEMParams(), CvMat* labels=0 );
598
599    virtual ~CvEM();
600
601    virtual bool train( const CvMat* samples, const CvMat* sample_idx=0,
602                        CvEMParams params=CvEMParams(), CvMat* labels=0 );
603
604    virtual float predict( const CvMat* sample, CvMat* probs ) const;
605    virtual void clear();
606
607    int get_nclusters() const;
608    const CvMat* get_means() const;
609    const CvMat** get_covs() const;
610    const CvMat* get_weights() const;
611    const CvMat* get_probs() const;
612
613    inline double get_log_likelihood () const { return log_likelihood; };
614
615protected:
616
617    virtual void set_params( const CvEMParams& params,
618                             const CvVectors& train_data );
619    virtual void init_em( const CvVectors& train_data );
620    virtual double run_em( const CvVectors& train_data );
621    virtual void init_auto( const CvVectors& samples );
622    virtual void kmeans( const CvVectors& train_data, int nclusters,
623                         CvMat* labels, CvTermCriteria criteria,
624                         const CvMat* means );
625    CvEMParams params;
626    double log_likelihood;
627
628    CvMat* means;
629    CvMat** covs;
630    CvMat* weights;
631    CvMat* probs;
632
633    CvMat* log_weight_div_det;
634    CvMat* inv_eigen_values;
635    CvMat** cov_rotate_mats;
636};
637
638/****************************************************************************************\
639*                                      Decision Tree                                     *
640\****************************************************************************************/
641
// Pair of an int and a float; CvDTreeTrainData::get_ord_var_data() returns a
// pointer to elements of this type.
struct CvPair32s32f
{
    int   i;
    float val;
};
647
648
// Direction of category <idx> with respect to the bit mask <subset>
// (32 categories per int): evaluates to -1 when bit <idx> is set and to +1
// when it is clear.
// The mask bit is built from an unsigned 1, so idx%32 == 31 does not shift
// into the sign bit of a signed int, and <subset> is parenthesized so that
// any pointer expression can be passed.
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*(((subset)[(idx)>>5]&(1u << ((idx) & 31)))==0)-1)
651
// One split of a decision tree node; splits are chained through <next>.
// The anonymous union overlays two views of the split data: <subset> and <ord>.
// NOTE(review): presumably <subset> is the bit mask read by CV_DTREE_CAT_DIR
// for categorical variables and <ord> serves ordered variables - confirm.
struct CvDTreeSplit
{
    int   var_idx;
    int   inversed;
    float quality;
    CvDTreeSplit* next;

    union
    {
        int subset[2];
        struct
        {
            float c;
            int   split_point;
        }
        ord;
    };
};
669
670
// Only a pointer to the split type is stored in the node, so a declaration suffices.
struct CvDTreeSplit;

// Node of a decision tree, linked to its parent and to its two children.
struct CvDTreeNode
{
    int class_idx;
    int Tn;
    double value;

    CvDTreeNode* parent;
    CvDTreeNode* left;
    CvDTreeNode* right;

    CvDTreeSplit* split;

    int sample_count;
    int depth;
    int* num_valid;     // optional per-variable counters; may be null
    int offset;
    int buf_idx;
    double maxlr;

    // global pruning data
    int complexity;
    double alpha;
    double node_risk, tree_risk, tree_error;

    // cross-validation pruning data
    int* cv_Tn;
    double* cv_node_risk;
    double* cv_node_error;

    // Counter stored for variable <vi>, or <sample_count> when no counter
    // table is attached. Declared const so that it can be called through the
    // const node pointer returned by CvDTree::get_root().
    int get_num_valid(int vi) const { return num_valid ? num_valid[vi] : sample_count; }

    // Stores <n> as the counter of variable <vi>; does nothing when no
    // counter table is attached.
    void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
};
703
704
705struct CV_EXPORTS CvDTreeParams
706{
707    int   max_categories;
708    int   max_depth;
709    int   min_sample_count;
710    int   cv_folds;
711    bool  use_surrogates;
712    bool  use_1se_rule;
713    bool  truncate_pruned_tree;
714    float regression_accuracy;
715    const float* priors;
716
717    CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
718        cv_folds(10), use_surrogates(true), use_1se_rule(true),
719        truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
720    {}
721
722    CvDTreeParams( int _max_depth, int _min_sample_count,
723                   float _regression_accuracy, bool _use_surrogates,
724                   int _max_categories, int _cv_folds,
725                   bool _use_1se_rule, bool _truncate_pruned_tree,
726                   const float* _priors ) :
727        max_categories(_max_categories), max_depth(_max_depth),
728        min_sample_count(_min_sample_count), cv_folds (_cv_folds),
729        use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
730        truncate_pruned_tree(_truncate_pruned_tree),
731        regression_accuracy(_regression_accuracy),
732        priors(_priors)
733    {}
734};
735
736
737struct CV_EXPORTS CvDTreeTrainData
738{
739    CvDTreeTrainData();
740    CvDTreeTrainData( const CvMat* _train_data, int _tflag,
741                      const CvMat* _responses, const CvMat* _var_idx=0,
742                      const CvMat* _sample_idx=0, const CvMat* _var_type=0,
743                      const CvMat* _missing_mask=0,
744                      const CvDTreeParams& _params=CvDTreeParams(),
745                      bool _shared=false, bool _add_labels=false );
746    virtual ~CvDTreeTrainData();
747
748    virtual void set_data( const CvMat* _train_data, int _tflag,
749                          const CvMat* _responses, const CvMat* _var_idx=0,
750                          const CvMat* _sample_idx=0, const CvMat* _var_type=0,
751                          const CvMat* _missing_mask=0,
752                          const CvDTreeParams& _params=CvDTreeParams(),
753                          bool _shared=false, bool _add_labels=false,
754                          bool _update_data=false );
755
756    virtual void get_vectors( const CvMat* _subsample_idx,
757         float* values, uchar* missing, float* responses, bool get_class_idx=false );
758
759    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
760
761    virtual void write_params( CvFileStorage* fs );
762    virtual void read_params( CvFileStorage* fs, CvFileNode* node );
763
764    // release all the data
765    virtual void clear();
766
767    int get_num_classes() const;
768    int get_var_type(int vi) const;
769    int get_work_var_count() const;
770
771    virtual int* get_class_labels( CvDTreeNode* n );
772    virtual float* get_ord_responses( CvDTreeNode* n );
773    virtual int* get_labels( CvDTreeNode* n );
774    virtual int* get_cat_var_data( CvDTreeNode* n, int vi );
775    virtual CvPair32s32f* get_ord_var_data( CvDTreeNode* n, int vi );
776    virtual int get_child_buf_idx( CvDTreeNode* n );
777
778    ////////////////////////////////////
779
780    virtual bool set_params( const CvDTreeParams& params );
781    virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
782                                   int storage_idx, int offset );
783
784    virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
785                int split_point, int inversed, float quality );
786    virtual CvDTreeSplit* new_split_cat( int vi, float quality );
787    virtual void free_node_data( CvDTreeNode* node );
788    virtual void free_train_data();
789    virtual void free_node( CvDTreeNode* node );
790
791    int sample_count, var_all, var_count, max_c_count;
792    int ord_var_count, cat_var_count;
793    bool have_labels, have_priors;
794    bool is_classifier;
795
796    int buf_count, buf_size;
797    bool shared;
798
799    CvMat* cat_count;
800    CvMat* cat_ofs;
801    CvMat* cat_map;
802
803    CvMat* counts;
804    CvMat* buf;
805    CvMat* direction;
806    CvMat* split_buf;
807
808    CvMat* var_idx;
809    CvMat* var_type; // i-th element =
810                     //   k<0  - ordered
811                     //   k>=0 - categorical, see k-th element of cat_* arrays
812    CvMat* priors;
813    CvMat* priors_mult;
814
815    CvDTreeParams params;
816
817    CvMemStorage* tree_storage;
818    CvMemStorage* temp_storage;
819
820    CvDTreeNode* data_root;
821
822    CvSet* node_heap;
823    CvSet* split_heap;
824    CvSet* cv_heap;
825    CvSet* nv_heap;
826
827    CvRNG rng;
828};
829
830
831class CV_EXPORTS CvDTree : public CvStatModel
832{
833public:
834    CvDTree();
835    virtual ~CvDTree();
836
837    virtual bool train( const CvMat* _train_data, int _tflag,
838                        const CvMat* _responses, const CvMat* _var_idx=0,
839                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
840                        const CvMat* _missing_mask=0,
841                        CvDTreeParams params=CvDTreeParams() );
842
843    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
844
845    virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
846                                  bool preprocessed_input=false ) const;
847    virtual const CvMat* get_var_importance();
848    virtual void clear();
849
850    virtual void read( CvFileStorage* fs, CvFileNode* node );
851    virtual void write( CvFileStorage* fs, const char* name );
852
853    // special read & write methods for trees in the tree ensembles
854    virtual void read( CvFileStorage* fs, CvFileNode* node,
855                       CvDTreeTrainData* data );
856    virtual void write( CvFileStorage* fs );
857
858    const CvDTreeNode* get_root() const;
859    int get_pruned_tree_idx() const;
860    CvDTreeTrainData* get_data();
861
862protected:
863
864    virtual bool do_train( const CvMat* _subsample_idx );
865
866    virtual void try_split_node( CvDTreeNode* n );
867    virtual void split_node_data( CvDTreeNode* n );
868    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
869    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
870    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
871    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
872    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
873    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
874    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
875    virtual double calc_node_dir( CvDTreeNode* node );
876    virtual void complete_node_dir( CvDTreeNode* node );
877    virtual void cluster_categories( const int* vectors, int vector_count,
878        int var_count, int* sums, int k, int* cluster_labels );
879
880    virtual void calc_node_value( CvDTreeNode* node );
881
882    virtual void prune_cv();
883    virtual double update_tree_rnc( int T, int fold );
884    virtual int cut_tree( int T, int fold, double min_alpha );
885    virtual void free_prune_data(bool cut_tree);
886    virtual void free_tree();
887
888    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node );
889    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split );
890    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
891    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
892    virtual void write_tree_nodes( CvFileStorage* fs );
893    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
894
895    CvDTreeNode* root;
896
897    int pruned_tree_idx;
898    CvMat* var_importance;
899
900    CvDTreeTrainData* data;
901};
902
903
904/****************************************************************************************\
905*                                   Random Trees Classifier                              *
906\****************************************************************************************/
907
908class CvRTrees;
909
910class CV_EXPORTS CvForestTree: public CvDTree
911{
912public:
913    CvForestTree();
914    virtual ~CvForestTree();
915
916    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );
917
918    virtual int get_var_count() const {return data ? data->var_count : 0;}
919    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
920
921    /* dummy methods to avoid warnings: BEGIN */
922    virtual bool train( const CvMat* _train_data, int _tflag,
923                        const CvMat* _responses, const CvMat* _var_idx=0,
924                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
925                        const CvMat* _missing_mask=0,
926                        CvDTreeParams params=CvDTreeParams() );
927
928    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
929    virtual void read( CvFileStorage* fs, CvFileNode* node );
930    virtual void read( CvFileStorage* fs, CvFileNode* node,
931                       CvDTreeTrainData* data );
932    /* dummy methods to avoid warnings: END */
933
934protected:
935    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
936    CvRTrees* forest;
937};
938
939
940struct CV_EXPORTS CvRTParams : public CvDTreeParams
941{
942    //Parameters for the forest
943    bool calc_var_importance; // true <=> RF processes variable importance
944    int nactive_vars;
945    CvTermCriteria term_crit;
946
947    CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
948        calc_var_importance(false), nactive_vars(0)
949    {
950        term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
951    }
952
953    CvRTParams( int _max_depth, int _min_sample_count,
954                float _regression_accuracy, bool _use_surrogates,
955                int _max_categories, const float* _priors, bool _calc_var_importance,
956                int _nactive_vars, int max_num_of_trees_in_the_forest,
957                float forest_accuracy, int termcrit_type ) :
958        CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
959                       _use_surrogates, _max_categories, 0,
960                       false, false, _priors ),
961        calc_var_importance(_calc_var_importance),
962        nactive_vars(_nactive_vars)
963    {
964        term_crit = cvTermCriteria(termcrit_type,
965            max_num_of_trees_in_the_forest, forest_accuracy);
966    }
967};
968
969
970class CV_EXPORTS CvRTrees : public CvStatModel
971{
972public:
973    CvRTrees();
974    virtual ~CvRTrees();
975    virtual bool train( const CvMat* _train_data, int _tflag,
976                        const CvMat* _responses, const CvMat* _var_idx=0,
977                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
978                        const CvMat* _missing_mask=0,
979                        CvRTParams params=CvRTParams() );
980    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
981    virtual void clear();
982
983    virtual const CvMat* get_var_importance();
984    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
985        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
986
987    virtual void read( CvFileStorage* fs, CvFileNode* node );
988    virtual void write( CvFileStorage* fs, const char* name );
989
990    CvMat* get_active_var_mask();
991    CvRNG* get_rng();
992
993    int get_tree_count() const;
994    CvForestTree* get_tree(int i) const;
995
996protected:
997
998    bool grow_forest( const CvTermCriteria term_crit );
999
1000    // array of the trees of the forest
1001    CvForestTree** trees;
1002    CvDTreeTrainData* data;
1003    int ntrees;
1004    int nclasses;
1005    double oob_error;
1006    CvMat* var_importance;
1007    int nsamples;
1008
1009    CvRNG rng;
1010    CvMat* active_var_mask;
1011};
1012
1013
1014/****************************************************************************************\
1015*                                   Boosted tree classifier                              *
1016\****************************************************************************************/
1017
1018struct CV_EXPORTS CvBoostParams : public CvDTreeParams
1019{
1020    int boost_type;
1021    int weak_count;
1022    int split_criteria;
1023    double weight_trim_rate;
1024
1025    CvBoostParams();
1026    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
1027                   int max_depth, bool use_surrogates, const float* priors );
1028};
1029
1030
1031class CvBoost;
1032
1033class CV_EXPORTS CvBoostTree: public CvDTree
1034{
1035public:
1036    CvBoostTree();
1037    virtual ~CvBoostTree();
1038
1039    virtual bool train( CvDTreeTrainData* _train_data,
1040                        const CvMat* subsample_idx, CvBoost* ensemble );
1041
1042    virtual void scale( double s );
1043    virtual void read( CvFileStorage* fs, CvFileNode* node,
1044                       CvBoost* ensemble, CvDTreeTrainData* _data );
1045    virtual void clear();
1046
1047    /* dummy methods to avoid warnings: BEGIN */
1048    virtual bool train( const CvMat* _train_data, int _tflag,
1049                        const CvMat* _responses, const CvMat* _var_idx=0,
1050                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1051                        const CvMat* _missing_mask=0,
1052                        CvDTreeParams params=CvDTreeParams() );
1053
1054    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
1055    virtual void read( CvFileStorage* fs, CvFileNode* node );
1056    virtual void read( CvFileStorage* fs, CvFileNode* node,
1057                       CvDTreeTrainData* data );
1058    /* dummy methods to avoid warnings: END */
1059
1060protected:
1061
1062    virtual void try_split_node( CvDTreeNode* n );
1063    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
1064    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
1065    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
1066    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
1067    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
1068    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
1069    virtual void calc_node_value( CvDTreeNode* n );
1070    virtual double calc_node_dir( CvDTreeNode* n );
1071
1072    CvBoost* ensemble;
1073};
1074
1075
1076class CV_EXPORTS CvBoost : public CvStatModel
1077{
1078public:
1079    // Boosting type
1080    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
1081
1082    // Splitting criteria
1083    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
1084
1085    CvBoost();
1086    virtual ~CvBoost();
1087
1088    CvBoost( const CvMat* _train_data, int _tflag,
1089             const CvMat* _responses, const CvMat* _var_idx=0,
1090             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1091             const CvMat* _missing_mask=0,
1092             CvBoostParams params=CvBoostParams() );
1093
1094    virtual bool train( const CvMat* _train_data, int _tflag,
1095             const CvMat* _responses, const CvMat* _var_idx=0,
1096             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1097             const CvMat* _missing_mask=0,
1098             CvBoostParams params=CvBoostParams(),
1099             bool update=false );
1100
1101    virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
1102                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1103                           bool raw_mode=false ) const;
1104
1105    virtual void prune( CvSlice slice );
1106
1107    virtual void clear();
1108
1109    virtual void write( CvFileStorage* storage, const char* name );
1110    virtual void read( CvFileStorage* storage, CvFileNode* node );
1111
1112    CvSeq* get_weak_predictors();
1113
1114    CvMat* get_weights();
1115    CvMat* get_subtree_weights();
1116    CvMat* get_weak_response();
1117    const CvBoostParams& get_params() const;
1118
1119protected:
1120
1121    virtual bool set_params( const CvBoostParams& _params );
1122    virtual void update_weights( CvBoostTree* tree );
1123    virtual void trim_weights();
1124    virtual void write_params( CvFileStorage* fs );
1125    virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1126
1127    CvDTreeTrainData* data;
1128    CvBoostParams params;
1129    CvSeq* weak;
1130
1131    CvMat* orig_response;
1132    CvMat* sum_response;
1133    CvMat* weak_eval;
1134    CvMat* subsample_mask;
1135    CvMat* weights;
1136    CvMat* subtree_weights;
1137    bool have_subsample;
1138};
1139
1140
1141/****************************************************************************************\
1142*                              Artificial Neural Networks (ANN)                          *
1143\****************************************************************************************/
1144
1145/////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1146
1147struct CV_EXPORTS CvANN_MLP_TrainParams
1148{
1149    CvANN_MLP_TrainParams();
1150    CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
1151                           double param1, double param2=0 );
1152    ~CvANN_MLP_TrainParams();
1153
1154    enum { BACKPROP=0, RPROP=1 };
1155
1156    CvTermCriteria term_crit;
1157    int train_method;
1158
1159    // backpropagation parameters
1160    double bp_dw_scale, bp_moment_scale;
1161
1162    // rprop parameters
1163    double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
1164};
1165
1166
1167class CV_EXPORTS CvANN_MLP : public CvStatModel
1168{
1169public:
1170    CvANN_MLP();
1171    CvANN_MLP( const CvMat* _layer_sizes,
1172               int _activ_func=SIGMOID_SYM,
1173               double _f_param1=0, double _f_param2=0 );
1174
1175    virtual ~CvANN_MLP();
1176
1177    virtual void create( const CvMat* _layer_sizes,
1178                         int _activ_func=SIGMOID_SYM,
1179                         double _f_param1=0, double _f_param2=0 );
1180
1181    virtual int train( const CvMat* _inputs, const CvMat* _outputs,
1182                       const CvMat* _sample_weights, const CvMat* _sample_idx=0,
1183                       CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
1184                       int flags=0 );
1185    virtual float predict( const CvMat* _inputs,
1186                           CvMat* _outputs ) const;
1187
1188    virtual void clear();
1189
1190    // possible activation functions
1191    enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1192
1193    // available training flags
1194    enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
1195
1196    virtual void read( CvFileStorage* fs, CvFileNode* node );
1197    virtual void write( CvFileStorage* storage, const char* name );
1198
1199    int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
1200    const CvMat* get_layer_sizes() { return layer_sizes; }
1201    double* get_weights(int layer)
1202    {
1203        return layer_sizes && weights &&
1204            (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
1205    }
1206
1207protected:
1208
1209    virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
1210            const CvMat* _sample_weights, const CvMat* _sample_idx,
1211            CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1212
1213    // sequential random backpropagation
1214    virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1215
1216    // RPROP algorithm
1217    virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1218
1219    virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1220    virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1221    virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
1222                                 double _f_param1=0, double _f_param2=0 );
1223    virtual void init_weights();
1224    virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1225    virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1226    virtual void calc_input_scale( const CvVectors* vecs, int flags );
1227    virtual void calc_output_scale( const CvVectors* vecs, int flags );
1228
1229    virtual void write_params( CvFileStorage* fs );
1230    virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1231
1232    CvMat* layer_sizes;
1233    CvMat* wbuf;
1234    CvMat* sample_weights;
1235    double** weights;
1236    double f_param1, f_param2;
1237    double min_val, max_val, min_val1, max_val1;
1238    int activ_func;
1239    int max_count, max_buf_sz;
1240    CvANN_MLP_TrainParams params;
1241    CvRNG rng;
1242};
1243
1244#if 0
1245/****************************************************************************************\
1246*                            Convolutional Neural Network                                *
1247\****************************************************************************************/
/* Forward typedefs so the callback signatures can refer to these types
   before the structs are defined. */
typedef struct CvCNNLayer CvCNNLayer;
typedef struct CvCNNetwork CvCNNetwork;

/* Values for the learn_rate_decrease_type layer field */
#define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY 1
#define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV 2
#define CV_CNN_LEARN_RATE_DECREASE_LOG_INV 3

/* Gradient estimation modes */
#define CV_CNN_GRAD_ESTIM_RANDOM 0
#define CV_CNN_GRAD_ESTIM_BY_WORST_IMG 1

/* Layer tags tested by the ICV_IS_CNN_* macros: ICV_CNN_LAYER is compared
   with (flags & CV_MAGIC_MASK), the other three with (flags & ~CV_MAGIC_MASK). */
#define ICV_CNN_LAYER 0x55550000
#define ICV_CNN_CONVOLUTION_LAYER 0x00001111
#define ICV_CNN_SUBSAMPLING_LAYER 0x00002222
#define ICV_CNN_FULLCONNECT_LAYER 0x00003333
1262
/* Non-zero when `layer` is non-NULL and the CV_MAGIC_MASK part of its flags
   equals ICV_CNN_LAYER. Note that `layer` is evaluated more than once. */
#define ICV_IS_CNN_LAYER( layer ) \
    ( ((layer) != NULL) && \
      ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK) == ICV_CNN_LAYER) )

/* The three predicates below first apply ICV_IS_CNN_LAYER and then compare
   the remaining flag bits (flags & ~CV_MAGIC_MASK) with the specific tag. */
#define ICV_IS_CNN_CONVOLUTION_LAYER( layer ) \
    ( (ICV_IS_CNN_LAYER( layer )) && \
      (((CvCNNLayer*) (layer))->flags & ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER )

#define ICV_IS_CNN_SUBSAMPLING_LAYER( layer ) \
    ( (ICV_IS_CNN_LAYER( layer )) && \
      (((CvCNNLayer*) (layer))->flags & ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER )

#define ICV_IS_CNN_FULLCONNECT_LAYER( layer ) \
    ( (ICV_IS_CNN_LAYER( layer )) && \
      (((CvCNNLayer*) (layer))->flags & ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER )
1278
1279typedef void (CV_CDECL *CvCNNLayerForward)
1280    ( CvCNNLayer* layer, const CvMat* input, CvMat* output );
1281
1282typedef void (CV_CDECL *CvCNNLayerBackward)
1283    ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
1284
1285typedef void (CV_CDECL *CvCNNLayerRelease)
1286    (CvCNNLayer** layer);
1287
1288typedef void (CV_CDECL *CvCNNetworkAddLayer)
1289    (CvCNNetwork* network, CvCNNLayer* layer);
1290
1291typedef void (CV_CDECL *CvCNNetworkRelease)
1292    (CvCNNetwork** network);
1293
/* Common leading fields of every CNN layer struct. The expansion ends
   without a semicolon, so uses are written `CV_CNN_LAYER_FIELDS();`. */
#define CV_CNN_LAYER_FIELDS() \
    /* Indicator of the layer's type */ \
    int flags; \
    \
    /* Number of input images */ \
    int n_input_planes; \
    /* Height of each input image */ \
    int input_height; \
    /* Width of each input image */ \
    int input_width; \
    \
    /* Number of output images */ \
    int n_output_planes; \
    /* Height of each output image */ \
    int output_height; \
    /* Width of each output image */ \
    int output_width; \
    \
    /* Learning rate at the first iteration */ \
    float init_learn_rate; \
    /* Dynamics of learning rate decreasing */ \
    int learn_rate_decrease_type; \
    /* Trainable weights of the layer (including bias) */ \
    /* i-th row is a set of weights of the i-th output plane */ \
    CvMat* weights; \
    \
    CvCNNLayerForward  forward; \
    CvCNNLayerBackward backward; \
    CvCNNLayerRelease  release; \
    /* Pointers to the previous and next layers in the network */ \
    CvCNNLayer* prev_layer; \
    CvCNNLayer* next_layer
1326
1327typedef struct CvCNNLayer
1328{
1329    CV_CNN_LAYER_FIELDS();
1330}CvCNNLayer;
1331
1332typedef struct CvCNNConvolutionLayer
1333{
1334    CV_CNN_LAYER_FIELDS();
1335    // Kernel size (height and width) for convolution.
1336    int K;
1337    // connections matrix, (i,j)-th element is 1 iff there is a connection between
1338    // i-th plane of the current layer and j-th plane of the previous layer;
1339    // (i,j)-th element is equal to 0 otherwise
1340    CvMat *connect_mask;
1341    // value of the learning rate for updating weights at the first iteration
1342}CvCNNConvolutionLayer;
1343
1344typedef struct CvCNNSubSamplingLayer
1345{
1346    CV_CNN_LAYER_FIELDS();
1347    // ratio between the heights (or widths - ratios are supposed to be equal)
1348    // of the input and output planes
1349    int sub_samp_scale;
1350    // amplitude of sigmoid activation function
1351    float a;
1352    // scale parameter of sigmoid activation function
1353    float s;
1354    // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X
1355    // - is the vector used in computing of the activation function in backward
1356    CvMat* exp2ssumWX;
1357    // (x1+x2+x3+x4), where x1,...x4 are some elements of X
1358    // - is the vector used in computing of the activation function in backward
1359    CvMat* sumX;
1360}CvCNNSubSamplingLayer;
1361
1362// Structure of the last layer.
1363typedef struct CvCNNFullConnectLayer
1364{
1365    CV_CNN_LAYER_FIELDS();
1366    // amplitude of sigmoid activation function
1367    float a;
1368    // scale parameter of sigmoid activation function
1369    float s;
1370    // exp2ssumWX = exp(2*<s>*(W*X)) - is the vector used in computing of the
1371    // activation function and it's derivative by the formulae
1372    // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
1373    // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
1374    CvMat* exp2ssumWX;
1375}CvCNNFullConnectLayer;
1376
1377typedef struct CvCNNetwork
1378{
1379    int n_layers;
1380    CvCNNLayer* layers;
1381    CvCNNetworkAddLayer add_layer;
1382    CvCNNetworkRelease release;
1383}CvCNNetwork;
1384
1385typedef struct CvCNNStatModel
1386{
1387    CV_STAT_MODEL_FIELDS();
1388    CvCNNetwork* network;
1389    // etalons are allocated as rows, the i-th etalon has label cls_labeles[i]
1390    CvMat* etalons;
1391    // classes labels
1392    CvMat* cls_labels;
1393}CvCNNStatModel;
1394
1395typedef struct CvCNNStatModelParams
1396{
1397    CV_STAT_MODEL_PARAM_FIELDS();
1398    // network must be created by the functions cvCreateCNNetwork and <add_layer>
1399    CvCNNetwork* network;
1400    CvMat* etalons;
1401    // termination criteria
1402    int max_iter;
1403    int start_iter;
1404    int grad_estim_type;
1405}CvCNNStatModelParams;
1406
1407CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
1408    int n_input_planes, int input_height, int input_width,
1409    int n_output_planes, int K,
1410    float init_learn_rate, int learn_rate_decrease_type,
1411    CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );
1412
1413CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
1414    int n_input_planes, int input_height, int input_width,
1415    int sub_samp_scale, float a, float s,
1416    float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );
1417
1418CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
1419    int n_inputs, int n_outputs, float a, float s,
1420    float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );
1421
1422CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );
1423
1424CVAPI(CvStatModel*) cvTrainCNNClassifier(
1425            const CvMat* train_data, int tflag,
1426            const CvMat* responses,
1427            const CvStatModelParams* params,
1428            const CvMat* CV_DEFAULT(0),
1429            const CvMat* sample_idx CV_DEFAULT(0),
1430            const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
1431
1432/****************************************************************************************\
1433*                               Estimate classifiers algorithms                          *
1434\****************************************************************************************/
1435typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
1436                    ( const CvStatModel* estimateModel );
1437
1438typedef int (CV_CDECL *CvStatModelEstimateNextStep)
1439                    ( CvStatModel* estimateModel );
1440
1441typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
1442                    ( CvStatModel* estimateModel,
1443                const CvStatModel* model,
1444                const CvMat*       features,
1445                      int          sample_t_flag,
1446                const CvMat*       responses );
1447
1448typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
1449                    ( CvStatModel* estimateModel,
1450                const CvStatModel* model );
1451
1452typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
1453                    ( const CvStatModel* estimateModel,
1454                            float*       correlation );
1455
1456typedef void (CV_CDECL *CvStatModelEstimateReset)
1457                    ( CvStatModel* estimateModel );
1458
1459//-------------------------------- Cross-validation --------------------------------------
1460#define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS()    \
1461    CV_STAT_MODEL_PARAM_FIELDS();                                 \
1462    int     k_fold;                                               \
1463    int     is_regression;                                        \
1464    CvRNG*  rng
1465
1466typedef struct CvCrossValidationParams
1467{
1468    CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS();
1469} CvCrossValidationParams;
1470
1471#define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS()    \
1472    CvStatModelEstimateGetMat               getTrainIdxMat; \
1473    CvStatModelEstimateGetMat               getCheckIdxMat; \
1474    CvStatModelEstimateNextStep             nextStep;       \
1475    CvStatModelEstimateCheckClassifier      check;          \
1476    CvStatModelEstimateGetCurrentResult     getResult;      \
1477    CvStatModelEstimateReset                reset;          \
1478    int     is_regression;                                  \
1479    int     folds_all;                                      \
1480    int     samples_all;                                    \
1481    int*    sampleIdxAll;                                   \
1482    int*    folds;                                          \
1483    int     max_fold_size;                                  \
1484    int         current_fold;                               \
1485    int         is_checked;                                 \
1486    CvMat*      sampleIdxTrain;                             \
1487    CvMat*      sampleIdxEval;                              \
1488    CvMat*      predict_results;                            \
1489    int     correct_results;                                \
1490    int     all_results;                                    \
1491    double  sq_error;                                       \
1492    double  sum_correct;                                    \
1493    double  sum_predict;                                    \
1494    double  sum_cc;                                         \
1495    double  sum_pp;                                         \
1496    double  sum_cp
1497
1498typedef struct CvCrossValidationModel
1499{
1500    CV_STAT_MODEL_FIELDS();
1501    CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
1502} CvCrossValidationModel;
1503
1504CVAPI(CvStatModel*)
1505cvCreateCrossValidationEstimateModel
1506           ( int                samples_all,
1507       const CvStatModelParams* estimateParams CV_DEFAULT(0),
1508       const CvMat*             sampleIdx CV_DEFAULT(0) );
1509
1510CVAPI(float)
1511cvCrossValidation( const CvMat*             trueData,
1512                         int                tflag,
1513                   const CvMat*             trueClasses,
1514                         CvStatModel*     (*createClassifier)( const CvMat*,
1515                                                                     int,
1516                                                               const CvMat*,
1517                                                               const CvStatModelParams*,
1518                                                               const CvMat*,
1519                                                               const CvMat*,
1520                                                               const CvMat*,
1521                                                               const CvMat* ),
1522                   const CvStatModelParams* estimateParams CV_DEFAULT(0),
1523                   const CvStatModelParams* trainParams CV_DEFAULT(0),
1524                   const CvMat*             compIdx CV_DEFAULT(0),
1525                   const CvMat*             sampleIdx CV_DEFAULT(0),
1526                         CvStatModel**      pCrValModel CV_DEFAULT(0),
1527                   const CvMat*             typeMask CV_DEFAULT(0),
1528                   const CvMat*             missedMeasurementMask CV_DEFAULT(0) );
1529#endif
1530
1531/****************************************************************************************\
*                           Auxiliary function declarations                              *
1533\****************************************************************************************/
1534
1535/* Generates <sample> from multivariate normal distribution, where <mean> - is an
1536   average row vector, <cov> - symmetric covariation matrix */
1537CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
1538                           CvRNG* rng CV_DEFAULT(0) );
1539
1540/* Generates sample from gaussian mixture distribution */
1541CVAPI(void) cvRandGaussMixture( CvMat* means[],
1542                               CvMat* covs[],
1543                               float weights[],
1544                               int clsnum,
1545                               CvMat* sample,
1546                               CvMat* sampClasses CV_DEFAULT(0) );
1547
1548#define CV_TS_CONCENTRIC_SPHERES 0
1549
1550/* creates test set */
1551CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
1552                 int num_samples,
1553                 int num_features,
1554                 CvMat** responses,
1555                 int num_classes, ... );
1556
1557/* Aij <- Aji for i > j if lower_to_upper != 0
1558              for i < j if lower_to_upper = 0 */
1559CVAPI(void) cvCompleteSymm( CvMat* matrix, int lower_to_upper );
1560
1561#endif
1562
1563#endif /*__ML_H__*/
1564/* End of file. */
1565