1/*M///////////////////////////////////////////////////////////////////////////////////////
2//
3//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4//
5//  By downloading, copying, installing or using the software you agree to this license.
6//  If you do not agree to this license, do not download, install,
7//  copy or use the software.
8//
9//
10//                        Intel License Agreement
11//
12// Copyright (C) 2000, Intel Corporation, all rights reserved.
13// Third party copyrights are property of their respective owners.
14//
15// Redistribution and use in source and binary forms, with or without modification,
16// are permitted provided that the following conditions are met:
17//
18//   * Redistribution's of source code must retain the above copyright notice,
19//     this list of conditions and the following disclaimer.
20//
21//   * Redistribution's in binary form must reproduce the above copyright notice,
22//     this list of conditions and the following disclaimer in the documentation
23//     and/or other materials provided with the distribution.
24//
25//   * The name of Intel Corporation may not be used to endorse or promote products
26//     derived from this software without specific prior written permission.
27//
28// This software is provided by the copyright holders and contributors "as is" and
29// any express or implied warranties, including, but not limited to, the implied
30// warranties of merchantability and fitness for a particular purpose are disclaimed.
31// In no event shall the Intel Corporation or contributors be liable for any direct,
32// indirect, incidental, special, exemplary, or consequential damages
33// (including, but not limited to, procurement of substitute goods or services;
34// loss of use, data, or profits; or business interruption) however caused
35// and on any theory of liability, whether in contract, strict liability,
36// or tort (including negligence or otherwise) arising in any way out of
37// the use of this software, even if advised of the possibility of such damage.
38//
39//M*/
40
41#ifndef __OPENCV_ML_HPP__
42#define __OPENCV_ML_HPP__
43
44#ifdef __cplusplus
45#  include "opencv2/core.hpp"
46#endif
47
48#include "opencv2/core/core_c.h"
49#include <limits.h>
50
51#ifdef __cplusplus
52
53#include <map>
54#include <iostream>
55
56// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definition in this header
58#undef check
59
60/****************************************************************************************\
61*                               Main struct definitions                                  *
62\****************************************************************************************/
63
64/* log(2*PI) */
65#define CV_LOG2PI (1.8378770664093454835606594728112)
66
67/* columns of <trainData> matrix are training samples */
68#define CV_COL_SAMPLE 0
69
70/* rows of <trainData> matrix are training samples */
71#define CV_ROW_SAMPLE 1
72
73#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
74
/* A chunk of training vectors (used e.g. by CvKNearest::samples).
   Chunks form a singly linked list through <next>; <type> records the
   element type of the vectors and presumably selects which member of
   the <data> union is valid -- TODO confirm against the .cpp. */
struct CvVectors
{
    int type;          // element type of the stored vectors
    int dims, count;   // vector dimensionality and number of vectors in this chunk
    CvVectors* next;   // next chunk in the list, or 0
    union
    {
        uchar** ptr;   // per-vector pointers, byte elements
        float** fl;    // per-vector pointers, float elements
        double** db;   // per-vector pointers, double elements
    } data;
};
87
#if 0
/* NOTE: this legacy lattice API is compiled out (#if 0); the CvParamGrid
   structure declared later in this header replaces it. */
/* A structure, representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1. */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

/* Builds a lattice from the given bounds and multiplicative step.
   The bounds are normalized (swapped if min_val > max_val) and the
   step is clamped from below to 1. */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

/* Returns an all-zero lattice. */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif
116
117/* Variable type */
118#define CV_VAR_NUMERICAL    0
119#define CV_VAR_ORDERED      0
120#define CV_VAR_CATEGORICAL  1
121
122#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
123#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
124#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
125#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
126#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
127#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
128#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
129#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
130#define CV_TYPE_NAME_ML_ERTREES     "opencv-ml-extremely-randomized-trees"
131#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"
132
133#define CV_TRAIN_ERROR  0
134#define CV_TEST_ERROR   1
135
/* Common base class of all statistical models in this header.
   Defines the persistence interface: save()/load() operate on files,
   write()/read() on an already opened CvFileStorage. */
class CvStatModel
{
public:
    CvStatModel();
    virtual ~CvStatModel();

    // Releases the internal data and resets the model to an empty state.
    virtual void clear();

    // Stores/restores the model to/from a file; <name> selects the node
    // inside the storage (0 presumably means a default name -- see
    // default_model_name below; TODO confirm in the .cpp).
    CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
    CV_WRAP virtual void load( const char* filename, const char* name=0 );

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

protected:
    const char* default_model_name;
};
153
154/****************************************************************************************\
155*                                 Normal Bayes Classifier                                *
156\****************************************************************************************/
157
/* The structure, representing the grid range of statmodel parameters.
   It is used for optimizing statmodel accuracy by varying model parameters,
   the accuracy estimate being computed by cross-validation.
   The grid is logarithmic, so <step> must be greater than 1. */
162
163class CvMLData;
164
/* Logarithmic grid of one statmodel parameter, used by CvSVM::train_auto()
   to search for the best parameter value by cross-validation (see the
   grid description comment above). <step> is the multiplicative step and
   must be greater than 1 for the grid to be iterated. */
struct CvParamGrid
{
    // SVM params type
    enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };

    // Default grid: all fields zero.
    CvParamGrid()
    {
        min_val = max_val = step = 0;
    }

    CvParamGrid( double min_val, double max_val, double log_step );
    //CvParamGrid( int param_id );
    // Validates the grid; exact criteria are defined in the .cpp -- TODO confirm.
    bool check() const;

    CV_PROP_RW double min_val;   // lower bound of the range
    CV_PROP_RW double max_val;   // upper bound of the range
    CV_PROP_RW double step;      // multiplicative (logarithmic) step
};
183
184inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
185{
186    min_val = _min_val;
187    max_val = _max_val;
188    step = _log_step;
189}
190
/* Normal (Gaussian) Bayes classifier. Training data are given either as
   CvMat pointers (legacy C API) or as cv::Mat references; varIdx/sampleIdx
   optionally select a subset of variables/samples. */
class CvNormalBayesClassifier : public CvStatModel
{
public:
    CV_WRAP CvNormalBayesClassifier();
    virtual ~CvNormalBayesClassifier();

    // Constructs and trains the model in one step (C API variant).
    CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx=0, const CvMat* sampleIdx=0 );

    // Trains the model; with update=true the existing model is updated
    // rather than trained from scratch.
    virtual bool train( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );

    // Predicts class labels for the given samples; optional outputs receive
    // per-sample results and probabilities.
    virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0, CV_OUT CvMat* results_prob=0 ) const;
    CV_WRAP virtual void clear();

    // C++ (cv::Mat) counterparts of the constructors/methods above.
    CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       bool update=false );
    CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0, CV_OUT cv::Mat* results_prob=0 ) const;

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

protected:
    int     var_count, var_all;   // number of used variables / total variables
    CvMat*  var_idx;              // indices of the used variables, or 0
    CvMat*  cls_labels;           // class labels seen during training
    // Per-class statistics arrays (one matrix per class -- TODO confirm):
    CvMat** count;
    CvMat** sum;
    CvMat** productsum;
    CvMat** avg;
    CvMat** inv_eigen_values;
    CvMat** cov_rotate_mats;
    CvMat*  c;
};
228
229
230/****************************************************************************************\
231*                          K-Nearest Neighbour Classifier                                *
232\****************************************************************************************/
233
234// k Nearest Neighbors
/* k-nearest-neighbours model. Stores the training samples (in CvVectors
   chunks) and classifies/regresses by voting/averaging over the k nearest
   stored samples. */
class CvKNearest : public CvStatModel
{
public:

    CV_WRAP CvKNearest();
    virtual ~CvKNearest();

    // Constructs and trains in one step (C API variant).
    CvKNearest( const CvMat* trainData, const CvMat* responses,
                const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );

    // Trains the model; maxK bounds the k usable in find_nearest(),
    // updateBase=true adds samples to the existing base instead of
    // retraining from scratch.
    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* sampleIdx=0, bool is_regression=false,
                        int maxK=32, bool updateBase=false );

    // Finds the k nearest neighbours of each input sample; optional outputs
    // receive predicted results, neighbour pointers/responses and distances.
    virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
        const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;

    // C++ (cv::Mat) counterparts of the constructors/methods above.
    CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
               const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );

    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
                       int maxK=32, bool updateBase=false );

    virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
                                const float** neighbors=0, cv::Mat* neighborResponses=0,
                                cv::Mat* dist=0 ) const;
    CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
                                        CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;

    virtual void clear();
    int get_max_k() const;
    int get_var_count() const;
    int get_sample_count() const;
    bool is_regression() const;

    // Internal helpers used by find_nearest() -- see the .cpp for details.
    virtual float write_results( int k, int k1, int start, int end,
        const float* neighbor_responses, const float* dist, CvMat* _results,
        CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;

    virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
        float* neighbor_responses, const float** neighbors, float* dist ) const;

protected:

    int max_k, var_count;   // maximum k allowed / number of variables
    int total;              // number of stored training samples
    bool regression;        // true when trained for regression
    CvVectors* samples;     // stored training samples (linked chunks)
};
285
286/****************************************************************************************\
287*                                   Support Vector Machines                              *
288\****************************************************************************************/
289
290// SVM training parameters
// SVM training parameters. Which fields are used depends on svm_type and
// kernel_type (see the per-field comments and the CvSVM enums below).
struct CvSVMParams
{
    CvSVMParams();
    CvSVMParams( int svm_type, int kernel_type,
                 double degree, double gamma, double coef0,
                 double Cvalue, double nu, double p,
                 CvMat* class_weights, CvTermCriteria term_crit );

    CV_PROP_RW int         svm_type;    // one of CvSVM::{C_SVC, NU_SVC, ONE_CLASS, EPS_SVR, NU_SVR}
    CV_PROP_RW int         kernel_type; // one of CvSVM::{LINEAR, POLY, RBF, SIGMOID, CHI2, INTER}
    CV_PROP_RW double      degree; // for poly
    CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid/chi2
    CV_PROP_RW double      coef0;  // for poly/sigmoid

    CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
    CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
    CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
    CvMat*      class_weights; // for CV_SVM_C_SVC
    CV_PROP_RW CvTermCriteria term_crit; // termination criteria
};
311
312
/* SVM kernel evaluator. calc() dispatches through the member-function
   pointer <calc_func> to one of the calc_* implementations selected by
   the kernel type; each computes kernel values between every vector in
   <vecs> and the single vector <another>, writing them to <results>. */
struct CvSVMKernel
{
    // Pointer-to-member type of a kernel evaluation routine.
    typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
                                       const float* another, float* results );
    CvSVMKernel();
    CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
    virtual bool create( const CvSVMParams* params, Calc _calc_func );
    virtual ~CvSVMKernel();

    virtual void clear();
    // Evaluates the kernel via calc_func for vcount vectors of length n.
    virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );

    const CvSVMParams* params;  // parameters (gamma, degree, coef0, ...) used by the kernels
    Calc calc_func;             // currently selected kernel routine

    // Shared helper for the kernels expressible as f(alpha*<x,y> + beta).
    virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
                                    const float* another, float* results,
                                    double alpha, double beta );
    // Per-kernel-type evaluation routines:
    virtual void calc_intersec( int vcount, int var_count, const float** vecs,
                            const float* another, float* results );
    virtual void calc_chi2( int vec_count, int vec_size, const float** vecs,
                              const float* another, float* results );
    virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
                              const float* another, float* results );
    virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
                           const float* another, float* results );
    virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
                            const float* another, float* results );
    virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
                               const float* another, float* results );
};
344
345
/* Node of the doubly linked LRU list that CvSVMSolver uses to cache
   computed kernel rows (see CvSVMSolver::lru_list / rows). */
struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;   // previous row in the LRU list
    CvSVMKernelRow* next;   // next row in the LRU list
    float* data;            // cached kernel row values
};
352
353
/* Summary of a solved SVM optimization subproblem, filled in by the
   CvSVMSolver::solve_* methods. */
struct CvSVMSolutionInfo
{
    double obj;             // final objective value -- presumably; TODO confirm
    double rho;             // decision function offset
    double upper_bound_p;
    double upper_bound_n;
    double r;   // for Solver_NU
};
362
/* SMO-style solver for the SVM quadratic programs. The per-formulation
   entry points (solve_c_svc, solve_nu_svc, ...) configure the generic
   machinery and run solve_generic(); strategy choices (working-set
   selection, rho computation, kernel-row retrieval) are pluggable via
   the member-function-pointer typedefs below. */
class CvSVMSolver
{
public:
    typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
    typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
    typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );

    CvSVMSolver();

    CvSVMSolver( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    virtual bool create( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    virtual ~CvSVMSolver();

    virtual void clear();
    // Core optimization loop shared by all the solve_* entry points.
    virtual bool solve_generic( CvSVMSolutionInfo& si );

    // Formulation-specific entry points:
    virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
                              double Cp, double Cn, CvMemStorage* storage,
                              CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_one_class( int count, int var_count, const float** samples,
                                  CvMemStorage* storage, CvSVMKernel* kernel,
                                  double* alpha, CvSVMSolutionInfo& si );

    virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
                                CvMemStorage* storage, CvSVMKernel* kernel,
                                double* alpha, CvSVMSolutionInfo& si );

    virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );

    // Kernel-row cache access (backed by the LRU list below).
    virtual float* get_row_base( int i, bool* _existed );
    virtual float* get_row( int i, float* dst );

    int sample_count;           // number of training samples
    int var_count;              // number of variables per sample
    int cache_size;             // kernel-row cache capacity
    int cache_line_size;
    const float** samples;      // training samples (not owned)
    const CvSVMParams* params;
    CvMemStorage* storage;
    CvSVMKernelRow lru_list;    // head of the LRU list of cached kernel rows
    CvSVMKernelRow* rows;

    int alpha_count;

    double* G;                  // gradient vector of the objective
    double* alpha;              // dual coefficients being optimized

    // -1 - lower bound, 0 - free, 1 - upper bound
    schar* alpha_status;

    schar* y;                   // sample labels (+1/-1)
    double* b;
    float* buf[2];              // temporary kernel-row buffers
    double eps;                 // termination tolerance
    int max_iter;               // iteration limit
    double C[2];  // C[0] == Cn, C[1] == Cp
    CvSVMKernel* kernel;

    // Currently selected strategy functions:
    SelectWorkingSet select_working_set_func;
    CalcRho calc_rho_func;
    GetRow get_row_func;

    // Default and nu-SVM variants of the pluggable strategies:
    virtual bool select_working_set( int& i, int& j );
    virtual bool select_working_set_nu_svm( int& i, int& j );
    virtual void calc_rho( double& rho, double& r );
    virtual void calc_rho_nu_svm( double& rho, double& r );

    virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
    virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
    virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
};
445
446
/* One SVM decision function: the offset <rho> plus <sv_count> weighted
   support vectors -- presumably alpha[i] weighs the support vector at
   sv_index[i]; TODO confirm against the .cpp. */
struct CvSVMDecisionFunc
{
    double rho;     // decision function offset
    int sv_count;   // number of support vectors used
    double* alpha;  // support vector weights
    int* sv_index;  // indices into the model's support vector array
};
454
455
456// SVM model
457class CvSVM : public CvStatModel
458{
459public:
460    // SVM type
461    enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
462
463    // SVM kernel type
464    enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3, CHI2=4, INTER=5 };
465
466    // SVM params type
467    enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
468
469    CV_WRAP CvSVM();
470    virtual ~CvSVM();
471
472    CvSVM( const CvMat* trainData, const CvMat* responses,
473           const CvMat* varIdx=0, const CvMat* sampleIdx=0,
474           CvSVMParams params=CvSVMParams() );
475
476    virtual bool train( const CvMat* trainData, const CvMat* responses,
477                        const CvMat* varIdx=0, const CvMat* sampleIdx=0,
478                        CvSVMParams params=CvSVMParams() );
479
480    virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
481        const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
482        int kfold = 10,
483        CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
484        CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
485        CvParamGrid pGrid      = get_default_grid(CvSVM::P),
486        CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
487        CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
488        CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
489        bool balanced=false );
490
491    virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
492    virtual float predict( const CvMat* samples, CV_OUT CvMat* results, bool returnDFVal=false ) const;
493
494    CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
495          const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
496          CvSVMParams params=CvSVMParams() );
497
498    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
499                       const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
500                       CvSVMParams params=CvSVMParams() );
501
502    CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
503                            const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
504                            int k_fold = 10,
505                            CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
506                            CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
507                            CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
508                            CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
509                            CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
510                            CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
511                            bool balanced=false);
512    CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
513    CV_WRAP_AS(predict_all) virtual void predict( cv::InputArray samples, cv::OutputArray results ) const;
514
515    CV_WRAP virtual int get_support_vector_count() const;
516    virtual const float* get_support_vector(int i) const;
517    virtual CvSVMParams get_params() const { return params; }
518    CV_WRAP virtual void clear();
519
520    virtual const CvSVMDecisionFunc* get_decision_function() const { return decision_func; }
521
522    static CvParamGrid get_default_grid( int param_id );
523
524    virtual void write( CvFileStorage* storage, const char* name ) const;
525    virtual void read( CvFileStorage* storage, CvFileNode* node );
526    CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
527
528protected:
529
530    virtual bool set_params( const CvSVMParams& params );
531    virtual bool train1( int sample_count, int var_count, const float** samples,
532                    const void* responses, double Cp, double Cn,
533                    CvMemStorage* _storage, double* alpha, double& rho );
534    virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
535                    const CvMat* responses, CvMemStorage* _storage, double* alpha );
536    virtual void create_kernel();
537    virtual void create_solver();
538
539    virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
540
541    virtual void write_params( CvFileStorage* fs ) const;
542    virtual void read_params( CvFileStorage* fs, CvFileNode* node );
543
544    void optimize_linear_svm();
545
546    CvSVMParams params;
547    CvMat* class_labels;
548    int var_all;
549    float** sv;
550    int sv_total;
551    CvMat* var_idx;
552    CvMat* class_weights;
553    CvSVMDecisionFunc* decision_func;
554    CvMemStorage* storage;
555
556    CvSVMSolver* solver;
557    CvSVMKernel* kernel;
558
559private:
560    CvSVM(const CvSVM&);
561    CvSVM& operator = (const CvSVM&);
562};
563
564/****************************************************************************************\
565*                                      Decision Tree                                     *
566\****************************************************************************************/\
/* Pair of pointers giving 16-bit and 32-bit views of index data --
   presumably used by the trees when the training buffer is stored in
   16-bit form (see CvDTreeTrainData::is_buf_16u); TODO confirm. */
struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};
572
573
574#define CV_DTREE_CAT_DIR(idx,subset) \
575    (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
576
/* One split (test) stored in a decision tree node. For an ordered variable
   the test compares the value against ord.c; for a categorical variable
   the direction is taken from the <subset> bit mask of categories (two
   32-bit words -- see the CV_DTREE_CAT_DIR macro above). */
struct CvDTreeSplit
{
    int var_idx;          // index of the variable the split is made on
    int condensed_idx;
    int inversed;         // nonzero presumably swaps the branch directions -- TODO confirm
    float quality;        // split quality measure
    CvDTreeSplit* next;   // next (surrogate) split in the list
    union
    {
        int subset[2];    // categorical split: bit mask over categories
        struct
        {
            float c;            // ordered split: comparison threshold
            int split_point;
        }
        ord;
    };
};
595
/* A decision tree node: tree links, the chosen split, the predicted
   value, plus bookkeeping for training and cost-complexity pruning. */
struct CvDTreeNode
{
    int class_idx;        // class index (classification trees)
    int Tn;               // pruning "sequence number" -- TODO confirm semantics
    double value;         // value predicted at this node

    CvDTreeNode* parent;
    CvDTreeNode* left;
    CvDTreeNode* right;

    CvDTreeSplit* split;  // primary split (surrogates chained via split->next)

    int sample_count;     // number of training samples reaching this node
    int depth;            // node depth (root is presumably 0 -- TODO confirm)
    int* num_valid;       // per-variable valid (non-missing) sample counts, or 0
    int offset;
    int buf_idx;
    double maxlr;

    // global pruning data
    int complexity;
    double alpha;
    double node_risk, tree_risk, tree_error;

    // cross-validation pruning data
    int* cv_Tn;
    double* cv_node_risk;
    double* cv_node_error;

    // When num_valid is absent every sample counts as valid.
    int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
    void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
};
628
629
/* Decision tree training parameters (shared by CvDTree and the tree
   ensembles that build on it). */
struct CvDTreeParams
{
    CV_PROP_RW int   max_categories;        // clustering threshold for categorical variables
    CV_PROP_RW int   max_depth;             // maximum tree depth
    CV_PROP_RW int   min_sample_count;      // do not split nodes with fewer samples
    CV_PROP_RW int   cv_folds;              // folds for cross-validation pruning (0 = off -- TODO confirm)
    CV_PROP_RW bool  use_surrogates;        // build surrogate splits for missing data
    CV_PROP_RW bool  use_1se_rule;          // harsher pruning via the 1-SE rule
    CV_PROP_RW bool  truncate_pruned_tree;  // physically remove pruned branches
    CV_PROP_RW float regression_accuracy;   // stop splitting when within this accuracy (regression)
    const float* priors;                    // optional class priors (not owned)

    CvDTreeParams();
    CvDTreeParams( int max_depth, int min_sample_count,
                   float regression_accuracy, bool use_surrogates,
                   int max_categories, int cv_folds,
                   bool use_1se_rule, bool truncate_pruned_tree,
                   const float* priors );
};
649
650
/* Preprocessed training data shared by a decision tree (or a whole tree
   ensemble, when <shared> is set). Holds the sorted/condensed per-variable
   buffers, variable-type information, priors, and the memory storages and
   heaps from which nodes and splits are allocated. */
struct CvDTreeTrainData
{
    CvDTreeTrainData();
    CvDTreeTrainData( const CvMat* trainData, int tflag,
                      const CvMat* responses, const CvMat* varIdx=0,
                      const CvMat* sampleIdx=0, const CvMat* varType=0,
                      const CvMat* missingDataMask=0,
                      const CvDTreeParams& params=CvDTreeParams(),
                      bool _shared=false, bool _add_labels=false );
    virtual ~CvDTreeTrainData();

    // (Re)initializes the structure from the given training data; tflag is
    // CV_ROW_SAMPLE or CV_COL_SAMPLE (orientation of <trainData>).
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    virtual void do_responses_copy();

    virtual void get_vectors( const CvMat* _subsample_idx,
         float* values, uchar* missing, float* responses, bool get_class_idx=false );

    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    // release all the data
    virtual void clear();

    int get_num_classes() const;
    int get_var_type(int vi) const;
    int get_work_var_count() const {return work_var_count;}

    // Accessors copying per-node data into caller-provided buffers:
    virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
    virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
                                   const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
    virtual int get_child_buf_idx( CvDTreeNode* n );

    ////////////////////////////////////

    virtual bool set_params( const CvDTreeParams& params );
    // Node/split allocation from the internal heaps:
    virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
                                   int storage_idx, int offset );

    virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality );
    virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    virtual void free_node_data( CvDTreeNode* node );
    virtual void free_train_data();
    virtual void free_node( CvDTreeNode* node );

    int sample_count, var_all, var_count, max_c_count;
    int ord_var_count, cat_var_count, work_var_count;
    bool have_labels, have_priors;
    bool is_classifier;
    int tflag;                 // sample orientation (CV_ROW_SAMPLE/CV_COL_SAMPLE)

    const CvMat* train_data;   // original training data (not owned)
    const CvMat* responses;    // original responses (not owned)
    CvMat* responses_copy; // used in Boosting

    int buf_count, buf_size; // buf_size is obsolete, please do not use it, use expression ((int64)buf->rows * (int64)buf->cols / buf_count) instead
    bool shared;               // true when shared between trees of an ensemble
    int is_buf_16u;            // nonzero when <buf> holds 16-bit indices

    // Categorical variable bookkeeping:
    CvMat* cat_count;
    CvMat* cat_ofs;
    CvMat* cat_map;

    CvMat* counts;
    CvMat* buf;                // the main working buffer
    // Length of one sub-buffer of <buf> (all work variables for one tree).
    inline size_t get_length_subbuf() const
    {
        size_t res = (size_t)(work_var_count + 1) * (size_t)sample_count;
        return res;
    }

    CvMat* direction;
    CvMat* split_buf;

    CvMat* var_idx;
    CvMat* var_type; // i-th element =
                     //   k<0  - ordered
                     //   k>=0 - categorical, see k-th element of cat_* arrays
    CvMat* priors;
    CvMat* priors_mult;

    CvDTreeParams params;      // training parameters in effect

    CvMemStorage* tree_storage;
    CvMemStorage* temp_storage;

    CvDTreeNode* data_root;    // root node covering the whole (sub)sample

    // Heaps the nodes/splits/cv-data are allocated from:
    CvSet* node_heap;
    CvSet* split_heap;
    CvSet* cv_heap;
    CvSet* nv_heap;

    cv::RNG* rng;              // random number generator used during training
};
758
// Forward declarations of the tree classes and of the cv:: helper structs
// that are granted friendship by the trees below.
class CvDTree;
class CvForestTree;

namespace cv
{
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}
767
768class CvDTree : public CvStatModel
769{
770public:
771    CV_WRAP CvDTree();
772    virtual ~CvDTree();
773
774    virtual bool train( const CvMat* trainData, int tflag,
775                        const CvMat* responses, const CvMat* varIdx=0,
776                        const CvMat* sampleIdx=0, const CvMat* varType=0,
777                        const CvMat* missingDataMask=0,
778                        CvDTreeParams params=CvDTreeParams() );
779
780    virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
781
782    // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
783    virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
784
785    virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
786
787    virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
788                                  bool preprocessedInput=false ) const;
789
790    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
791                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
792                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
793                       const cv::Mat& missingDataMask=cv::Mat(),
794                       CvDTreeParams params=CvDTreeParams() );
795
796    CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
797                                  bool preprocessedInput=false ) const;
798    CV_WRAP virtual cv::Mat getVarImportance();
799
800    virtual const CvMat* get_var_importance();
801    CV_WRAP virtual void clear();
802
803    virtual void read( CvFileStorage* fs, CvFileNode* node );
804    virtual void write( CvFileStorage* fs, const char* name ) const;
805
806    // special read & write methods for trees in the tree ensembles
807    virtual void read( CvFileStorage* fs, CvFileNode* node,
808                       CvDTreeTrainData* data );
809    virtual void write( CvFileStorage* fs ) const;
810
811    const CvDTreeNode* get_root() const;
812    int get_pruned_tree_idx() const;
813    CvDTreeTrainData* get_data();
814
815protected:
816    friend struct cv::DTreeBestSplitFinder;
817
818    virtual bool do_train( const CvMat* _subsample_idx );
819
820    virtual void try_split_node( CvDTreeNode* n );
821    virtual void split_node_data( CvDTreeNode* n );
822    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
823    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
824                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
825    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
826                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
827    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
828                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
829    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
830                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
831    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
832    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
833    virtual double calc_node_dir( CvDTreeNode* node );
834    virtual void complete_node_dir( CvDTreeNode* node );
835    virtual void cluster_categories( const int* vectors, int vector_count,
836        int var_count, int* sums, int k, int* cluster_labels );
837
838    virtual void calc_node_value( CvDTreeNode* node );
839
840    virtual void prune_cv();
841    virtual double update_tree_rnc( int T, int fold );
842    virtual int cut_tree( int T, int fold, double min_alpha );
843    virtual void free_prune_data(bool cut_tree);
844    virtual void free_tree();
845
846    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
847    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
848    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
849    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
850    virtual void write_tree_nodes( CvFileStorage* fs ) const;
851    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
852
853    CvDTreeNode* root;
854    CvMat* var_importance;
855    CvDTreeTrainData* data;
856    CvMat train_data_hdr, responses_hdr;
857    cv::Mat train_data_mat, responses_mat;
858
859public:
860    int pruned_tree_idx;
861};
862
863
864/****************************************************************************************\
865*                                   Random Trees Classifier                              *
866\****************************************************************************************/
867
868class CvRTrees;
869
// Single decision tree used as a member of a CvRTrees forest.
// Extends CvDTree with a training entry point that also receives the
// owning ensemble, so the tree can consult forest-level state while training.
class CvForestTree: public CvDTree
{
public:
    CvForestTree();
    virtual ~CvForestTree();

    // Trains the tree on a subsample of the shared training data;
    // `forest` is the ensemble this tree belongs to.
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );

    // Number of variables in the training data (0 if the tree has no data yet).
    virtual int get_var_count() const {return data ? data->var_count : 0;}
    // Deserializes the tree as part of a forest, attaching it to `forest` and `_data`.
    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:
    friend struct cv::ForestTreeBestSplitFinder;

    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    CvRTrees* forest;   // back-pointer to the owning ensemble (not owned)
};
900
901
// Training parameters of the Random Trees classifier. Extends the
// single-tree parameters (CvDTreeParams) with forest-level settings;
// the full constructor builds term_crit from the maximum tree count,
// the desired forest accuracy and the termination-criteria type.
struct CvRTParams : public CvDTreeParams
{
    //Parameters for the forest
    CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
    CV_PROP_RW int nactive_vars; // NOTE(review): presumably the number of variables considered per split -- confirm in implementation
    CV_PROP_RW CvTermCriteria term_crit; // forest-growing termination criteria (tree count and/or accuracy)

    CvRTParams();
    CvRTParams( int max_depth, int min_sample_count,
                float regression_accuracy, bool use_surrogates,
                int max_categories, const float* priors, bool calc_var_importance,
                int nactive_vars, int max_num_of_trees_in_the_forest,
                float forest_accuracy, int termcrit_type );
};
916
917
// Random Trees (random forest) classifier: an ensemble of CvForestTree
// decision trees whose individual predictions are combined into the
// final forest prediction.
class CvRTrees : public CvStatModel
{
public:
    CV_WRAP CvRTrees();
    virtual ~CvRTrees();
    // Trains the forest on C-style matrices. Parameter layout matches the
    // other tree-based trainers in this header (tflag selects row/column
    // sample layout; varIdx/sampleIdx select active variables/samples).
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams() );

    // Trains the forest from a CvMLData container.
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    // Ensemble prediction for a single sample (with optional missing-value mask).
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    // NOTE(review): presumably the confidence of the predicted class -- confirm in implementation.
    virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

    // cv::Mat overloads of the train/predict methods above.
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams() );
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    // Variable importance as a cv::Mat (see get_var_importance()).
    CV_WRAP virtual cv::Mat getVarImportance();

    CV_WRAP virtual void clear();

    // Variable importance estimated during training (meaningful when
    // CvRTParams::calc_var_importance was set).
    virtual const CvMat* get_var_importance();
    // Proximity between two samples. NOTE(review): presumably the share of
    // trees that place both samples in the same leaf -- confirm in implementation.
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    virtual float get_train_error();

    // Serialization of the trained forest.
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:
    virtual cv::String getName() const;

    // Builds trees until the termination criteria are satisfied.
    virtual bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;   // training data shared by all trees
    CvMat train_data_hdr, responses_hdr;
    cv::Mat train_data_mat, responses_mat;
    int ntrees;               // number of trees in `trees`
    int nclasses;
    double oob_error;         // out-of-bag error estimate
    CvMat* var_importance;
    int nsamples;

    cv::RNG* rng;
    CvMat* active_var_mask;
};
980
981/****************************************************************************************\
982*                           Extremely randomized trees Classifier                        *
983\****************************************************************************************/
// Training data for the Extremely Randomized Trees classifier.
// Overrides the data-access methods of CvDTreeTrainData and keeps a
// pointer to the missing-value mask of the training set.
struct CvERTreeTrainData : public CvDTreeTrainData
{
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
                                   const float** ord_values, const int** missing, int* sample_buf = 0 );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
                              float* responses, bool get_class_idx=false );
    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    const CvMat* missing_mask; // missing-value mask of the training set
};
1003
// Single tree of the Extremely Randomized Trees ensemble.
// Overrides the split-search and node-splitting methods of CvForestTree;
// the public training interface is inherited unchanged.
class CvForestERTree : public CvForestTree
{
protected:
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void split_node_data( CvDTreeNode* n );
};
1018
// Extremely Randomized Trees classifier. Shares the CvRTrees interface;
// only the forest-growing procedure (and the model name) is overridden.
class CvERTrees : public CvRTrees
{
public:
    CV_WRAP CvERTrees();
    virtual ~CvERTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams());
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams());
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
protected:
    virtual cv::String getName() const;
    // Builds the extremely randomized forest until term_crit is satisfied.
    virtual bool grow_forest( const CvTermCriteria term_crit );
};
1039
1040
1041/****************************************************************************************\
1042*                                   Boosted tree classifier                              *
1043\****************************************************************************************/
1044
// Training parameters of the boosted tree classifier (CvBoost).
struct CvBoostParams : public CvDTreeParams
{
    CV_PROP_RW int boost_type;      // one of CvBoost::{DISCRETE, REAL, LOGIT, GENTLE}
    CV_PROP_RW int weak_count;      // number of weak classifiers (trees) to build
    CV_PROP_RW int split_criteria;  // one of CvBoost::{DEFAULT, GINI, MISCLASS, SQERR}
    CV_PROP_RW double weight_trim_rate; // NOTE(review): presumably the weight-trimming threshold in (0,1] -- confirm in implementation

    CvBoostParams();
    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
                   int max_depth, bool use_surrogates, const float* priors );
};
1056
1057
1058class CvBoost;
1059
// Weak classifier of a CvBoost ensemble: a single decision tree that keeps
// a back-pointer to its owning ensemble.
class CvBoostTree: public CvDTree
{
public:
    CvBoostTree();
    virtual ~CvBoostTree();

    // Trains the tree as a weak classifier of `ensemble` on the given
    // subsample of the shared training data.
    virtual bool train( CvDTreeTrainData* trainData,
                        const CvMat* subsample_idx, CvBoost* ensemble );

    // Scales the tree output by factor s. NOTE(review): presumably scales
    // node values -- confirm in implementation.
    virtual void scale( double s );
    // Deserializes the tree as part of an ensemble, attaching it to
    // `ensemble` and `_data`.
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvBoost* ensemble, CvDTreeTrainData* _data );
    virtual void clear();

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:

    virtual void try_split_node( CvDTreeNode* n );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void calc_node_value( CvDTreeNode* n );
    virtual double calc_node_dir( CvDTreeNode* n );

    CvBoost* ensemble;  // back-pointer to the owning boosted ensemble (not owned)
};
1105
1106
// Boosted tree classifier: a sequence of CvBoostTree weak classifiers built
// by one of the boosting variants below.
class CvBoost : public CvStatModel
{
public:
    // Boosting type
    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };

    // Splitting criteria
    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };

    CV_WRAP CvBoost();
    virtual ~CvBoost();

    // Constructs and immediately trains the model (C-style matrices).
    CvBoost( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams() );

    // Trains the ensemble; when update=true, NOTE(review): presumably adds
    // weak classifiers to the existing model instead of retraining -- confirm.
    virtual bool train( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    virtual bool train( CvMLData* data,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    // Predicts the response for a sample; `weak_responses` may receive the
    // individual weak-classifier outputs, `slice` selects the part of the
    // ensemble to use, and `return_sum` returns the raw weighted sum instead
    // of the class label.
    virtual float predict( const CvMat* sample, const CvMat* missing=0,
                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
                           bool raw_mode=false, bool return_sum=false ) const;

    // cv::Mat overloads of the constructor/train/predict above.
    CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
            const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
            const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
            const cv::Mat& missingDataMask=cv::Mat(),
            CvBoostParams params=CvBoostParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvBoostParams params=CvBoostParams(),
                       bool update=false );

    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
                                   const cv::Range& slice=cv::Range::all(), bool rawMode=false,
                                   bool returnSum=false ) const;

    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    // Removes the weak classifiers selected by `slice` from the ensemble.
    CV_WRAP virtual void prune( CvSlice slice );

    CV_WRAP virtual void clear();

    // Serialization of the trained ensemble.
    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    virtual const CvMat* get_active_vars(bool absolute_idx=true);

    // Sequence of trained CvBoostTree weak classifiers.
    CvSeq* get_weak_predictors();

    CvMat* get_weights();
    CvMat* get_subtree_weights();
    CvMat* get_weak_response();
    const CvBoostParams& get_params() const;
    const CvDTreeTrainData* get_data() const;

protected:

    virtual bool set_params( const CvBoostParams& params );
    // Recomputes the per-sample weights after `tree` has been trained.
    virtual void update_weights( CvBoostTree* tree );
    // Drops low-weight samples from subsequent training iterations.
    virtual void trim_weights();
    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    virtual void initialize_weights(double (&p)[2]);

    CvDTreeTrainData* data;    // training data shared by all weak classifiers
    CvMat train_data_hdr, responses_hdr;
    cv::Mat train_data_mat, responses_mat;
    CvBoostParams params;
    CvSeq* weak;               // sequence of CvBoostTree* weak classifiers

    CvMat* active_vars;
    CvMat* active_vars_abs;
    bool have_active_cat_vars;

    // Training-time buffers (responses, accumulated sums, sample weights, ...).
    CvMat* orig_response;
    CvMat* sum_response;
    CvMat* weak_eval;
    CvMat* subsample_mask;
    CvMat* weights;
    CvMat* subtree_weights;
    bool have_subsample;
};
1203
1204
1205/****************************************************************************************\
1206*                                   Gradient Boosted Trees                               *
1207\****************************************************************************************/
1208
1209// DataType: STRUCT CvGBTreesParams
1210// Parameters of GBT (Gradient Boosted trees model), including single
1211// tree settings and ensemble parameters.
1212//
1213// weak_count          - count of trees in the ensemble
1214// loss_function_type  - loss function used for ensemble training
1215// subsample_portion   - portion of whole training set used for
1216//                       every single tree training.
1217//                       subsample_portion value is in (0.0, 1.0].
1218//                       subsample_portion == 1.0 when whole dataset is
1219//                       used on each step. Count of sample used on each
1220//                       step is computed as
1221//                       int(total_samples_count * subsample_portion).
1222// shrinkage           - regularization parameter.
1223//                       Each tree prediction is multiplied on shrinkage value.
1224
1225
// Parameters of the Gradient Boosted Trees model; see the field
// descriptions in the comment block above.
struct CvGBTreesParams : public CvDTreeParams
{
    CV_PROP_RW int weak_count;          // count of trees in the ensemble
    CV_PROP_RW int loss_function_type;  // loss function used for ensemble training
    CV_PROP_RW float subsample_portion; // portion of the training set used per tree, in (0.0, 1.0]
    CV_PROP_RW float shrinkage;         // regularization: each tree prediction is multiplied by this value

    CvGBTreesParams();
    CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
        float subsample_portion, int max_depth, bool use_surrogates );
};
1237
1238// DataType: CLASS CvGBTrees
1239// Gradient Boosting Trees (GBT) algorithm implementation.
1240//
1241// data             - training dataset
1242// params           - parameters of the CvGBTrees
1243// weak             - array[0..(class_count-1)] of CvSeq
1244//                    for storing tree ensembles
1245// orig_response    - original responses of the training set samples
// sum_response     - predictions of the current model on the training dataset.
//                    this matrix is updated on every iteration.
// sum_response_tmp - predictions of the model on the training set on the next
//                    step. On every iteration values of sum_responses_tmp are
//                    computed via sum_responses values. When the current
//                    step is complete sum_response values become equal to
//                    sum_responses_tmp.
1253// sampleIdx       - indices of samples used for training the ensemble.
1254//                    CvGBTrees training procedure takes a set of samples
1255//                    (train_data) and a set of responses (responses).
1256//                    Only pairs (train_data[i], responses[i]), where i is
1257//                    in sample_idx are used for training the ensemble.
// subsample_train  - indices of samples used for training a single decision
//                    tree on the current step. These indices are counted
//                    relative to the sample_idx, so that pairs
//                    (train_data[sample_idx[i]], responses[sample_idx[i]])
//                    are used for training a decision tree.
//                    The training set is randomly split
//                    in two parts (subsample_train and subsample_test)
//                    on every iteration according to the portion parameter.
1266// subsample_test   - relative indices of samples from the training set,
1267//                    which are not used for training a tree on the current
1268//                    step.
1269// missing          - mask of the missing values in the training set. This
1270//                    matrix has the same size as train_data. 1 - missing
1271//                    value, 0 - not a missing value.
1272// class_labels     - output class labels map.
// rng              - random number generator. Used for splitting the
//                    training set.
1275// class_count      - count of output classes.
1276//                    class_count == 1 in the case of regression,
1277//                    and > 1 in the case of classification.
1278// delta            - Huber loss function parameter.
1279// base_value       - start point of the gradient descent procedure.
1280//                    model prediction is
1281//                    f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
1282//                    f_0 is the base value.
1283
1284
1285
1286class CvGBTrees : public CvStatModel
1287{
1288public:
1289
1290    /*
1291    // DataType: ENUM
1292    // Loss functions implemented in CvGBTrees.
1293    //
1294    // SQUARED_LOSS
1295    // problem: regression
1296    // loss = (x - x')^2
1297    //
1298    // ABSOLUTE_LOSS
1299    // problem: regression
1300    // loss = abs(x - x')
1301    //
1302    // HUBER_LOSS
1303    // problem: regression
1304    // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
1305    //           1/2*(x - x')^2, if abs(x - x') <= delta,
1306    //           where delta is the alpha-quantile of pseudo responses from
1307    //           the training set.
1308    //
1309    // DEVIANCE_LOSS
1310    // problem: classification
1311    //
1312    */
1313    enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
1314
1315
1316    /*
1317    // Default constructor. Creates a model only (without training).
1318    // Should be followed by one form of the train(...) function.
1319    //
1320    // API
1321    // CvGBTrees();
1322
1323    // INPUT
1324    // OUTPUT
1325    // RESULT
1326    */
1327    CV_WRAP CvGBTrees();
1328
1329
1330    /*
1331    // Full form constructor. Creates a gradient boosting model and does the
1332    // train.
1333    //
1334    // API
1335    // CvGBTrees( const CvMat* trainData, int tflag,
1336             const CvMat* responses, const CvMat* varIdx=0,
1337             const CvMat* sampleIdx=0, const CvMat* varType=0,
1338             const CvMat* missingDataMask=0,
1339             CvGBTreesParams params=CvGBTreesParams() );
1340
1341    // INPUT
1342    // trainData    - a set of input feature vectors.
1343    //                  size of matrix is
1344    //                  <count of samples> x <variables count>
1345    //                  or <variables count> x <count of samples>
1346    //                  depending on the tflag parameter.
1347    //                  matrix values are float.
1348    // tflag         - a flag showing how do samples stored in the
1349    //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1350    //                  or column by column (tflag=CV_COL_SAMPLE).
1351    // responses     - a vector of responses corresponding to the samples
1352    //                  in trainData.
1353    // varIdx       - indices of used variables. zero value means that all
1354    //                  variables are active.
1355    // sampleIdx    - indices of used samples. zero value means that all
1356    //                  samples from trainData are in the training set.
1357    // varType      - vector of <variables count> length. gives every
1358    //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1359    //                  varType = 0 means all variables are numerical.
    // missingDataMask  - a mask of missing values in trainData.
1361    //                  missingDataMask = 0 means that there are no missing
1362    //                  values.
1363    // params         - parameters of GTB algorithm.
1364    // OUTPUT
1365    // RESULT
1366    */
1367    CvGBTrees( const CvMat* trainData, int tflag,
1368             const CvMat* responses, const CvMat* varIdx=0,
1369             const CvMat* sampleIdx=0, const CvMat* varType=0,
1370             const CvMat* missingDataMask=0,
1371             CvGBTreesParams params=CvGBTreesParams() );
1372
1373
1374    /*
1375    // Destructor.
1376    */
1377    virtual ~CvGBTrees();
1378
1379
1380    /*
1381    // Gradient tree boosting model training
1382    //
1383    // API
1384    // virtual bool train( const CvMat* trainData, int tflag,
1385             const CvMat* responses, const CvMat* varIdx=0,
1386             const CvMat* sampleIdx=0, const CvMat* varType=0,
1387             const CvMat* missingDataMask=0,
1388             CvGBTreesParams params=CvGBTreesParams(),
1389             bool update=false );
1390
1391    // INPUT
1392    // trainData    - a set of input feature vectors.
1393    //                  size of matrix is
1394    //                  <count of samples> x <variables count>
1395    //                  or <variables count> x <count of samples>
1396    //                  depending on the tflag parameter.
1397    //                  matrix values are float.
1398    // tflag         - a flag showing how do samples stored in the
1399    //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1400    //                  or column by column (tflag=CV_COL_SAMPLE).
1401    // responses     - a vector of responses corresponding to the samples
1402    //                  in trainData.
1403    // varIdx       - indices of used variables. zero value means that all
1404    //                  variables are active.
1405    // sampleIdx    - indices of used samples. zero value means that all
1406    //                  samples from trainData are in the training set.
1407    // varType      - vector of <variables count> length. gives every
1408    //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1409    //                  varType = 0 means all variables are numerical.
    // missingDataMask  - a mask of missing values in trainData.
1411    //                  missingDataMask = 0 means that there are no missing
1412    //                  values.
1413    // params         - parameters of GTB algorithm.
1414    // update         - is not supported now. (!)
1415    // OUTPUT
1416    // RESULT
1417    // Error state.
1418    */
1419    virtual bool train( const CvMat* trainData, int tflag,
1420             const CvMat* responses, const CvMat* varIdx=0,
1421             const CvMat* sampleIdx=0, const CvMat* varType=0,
1422             const CvMat* missingDataMask=0,
1423             CvGBTreesParams params=CvGBTreesParams(),
1424             bool update=false );
1425
1426
1427    /*
1428    // Gradient tree boosting model training
1429    //
1430    // API
1431    // virtual bool train( CvMLData* data,
1432             CvGBTreesParams params=CvGBTreesParams(),
1433             bool update=false ) {return false;}
1434
1435    // INPUT
1436    // data          - training set.
1437    // params        - parameters of GTB algorithm.
1438    // update        - is not supported now. (!)
1439    // OUTPUT
1440    // RESULT
1441    // Error state.
1442    */
1443    virtual bool train( CvMLData* data,
1444             CvGBTreesParams params=CvGBTreesParams(),
1445             bool update=false );
1446
1447
1448    /*
1449    // Response value prediction
1450    //
1451    // API
1452    // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1453             CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
1454             int k=-1 ) const;
1455
1456    // INPUT
1457    // sample         - input sample of the same type as in the training set.
1458    // missing        - missing values mask. missing=0 if there are no
1459    //                   missing values in sample vector.
1460    // weak_responses  - predictions of all of the trees.
1461    //                   not implemented (!)
1462    // slice           - part of the ensemble used for prediction.
1463    //                   slice = CV_WHOLE_SEQ when all trees are used.
1464    // k               - number of ensemble used.
1465    //                   k is in {-1,0,1,..,<count of output classes-1>}.
1466    //                   in the case of classification problem
1467    //                   <count of output classes-1> ensembles are built.
1468    //                   If k = -1 ordinary prediction is the result,
1469    //                   otherwise function gives the prediction of the
1470    //                   k-th ensemble only.
1471    // OUTPUT
1472    // RESULT
1473    // Predicted value.
1474    */
1475    virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1476            CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1477            int k=-1 ) const;
1478
1479    /*
1480    // Response value prediction.
1481    // Parallel version (in the case of TBB existence)
1482    //
1483    // API
1484    // virtual float predict( const CvMat* sample, const CvMat* missing=0,
1485             CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
1486             int k=-1 ) const;
1487
1488    // INPUT
1489    // sample         - input sample of the same type as in the training set.
1490    // missing        - missing values mask. missing=0 if there are no
1491    //                   missing values in sample vector.
1492    // weak_responses  - predictions of all of the trees.
1493    //                   not implemented (!)
1494    // slice           - part of the ensemble used for prediction.
1495    //                   slice = CV_WHOLE_SEQ when all trees are used.
1496    // k               - number of ensemble used.
1497    //                   k is in {-1,0,1,..,<count of output classes-1>}.
1498    //                   in the case of classification problem
1499    //                   <count of output classes-1> ensembles are built.
1500    //                   If k = -1 ordinary prediction is the result,
1501    //                   otherwise function gives the prediction of the
1502    //                   k-th ensemble only.
1503    // OUTPUT
1504    // RESULT
1505    // Predicted value.
1506    */
1507    virtual float predict( const CvMat* sample, const CvMat* missing=0,
1508            CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1509            int k=-1 ) const;
1510
1511    /*
1512    // Deletes all the data.
1513    //
1514    // API
1515    // virtual void clear();
1516
1517    // INPUT
1518    // OUTPUT
1519    // delete data, weak, orig_response, sum_response,
1520    //        weak_eval, subsample_train, subsample_test,
1521    //        sample_idx, missing, lass_labels
1522    // delta = 0.0
1523    // RESULT
1524    */
1525    CV_WRAP virtual void clear();
1526
1527    /*
1528    // Compute error on the train/test set.
1529    //
1530    // API
1531    // virtual float calc_error( CvMLData* _data, int type,
1532    //        std::vector<float> *resp = 0 );
1533    //
1534    // INPUT
1535    // data  - dataset
1536    // type  - defines which error is to compute: train (CV_TRAIN_ERROR) or
1537    //         test (CV_TEST_ERROR).
1538    // OUTPUT
1539    // resp  - vector of predicitons
1540    // RESULT
1541    // Error value.
1542    */
1543    virtual float calc_error( CvMLData* _data, int type,
1544            std::vector<float> *resp = 0 );
1545
1546    /*
1547    //
1548    // Write parameters of the gtb model and data. Write learned model.
1549    //
1550    // API
1551    // virtual void write( CvFileStorage* fs, const char* name ) const;
1552    //
1553    // INPUT
1554    // fs     - file storage to read parameters from.
1555    // name   - model name.
1556    // OUTPUT
1557    // RESULT
1558    */
1559    virtual void write( CvFileStorage* fs, const char* name ) const;
1560
1561
1562    /*
1563    //
1564    // Read parameters of the gtb model and data. Read learned model.
1565    //
1566    // API
1567    // virtual void read( CvFileStorage* fs, CvFileNode* node );
1568    //
1569    // INPUT
1570    // fs     - file storage to read parameters from.
1571    // node   - file node.
1572    // OUTPUT
1573    // RESULT
1574    */
1575    virtual void read( CvFileStorage* fs, CvFileNode* node );
1576
1577
    // new-style C++ interface

    // cv::Mat counterpart of the CvMat-based constructor above.
    CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
              const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
              const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
              const cv::Mat& missingDataMask=cv::Mat(),
              CvGBTreesParams params=CvGBTreesParams() );

    // cv::Mat counterpart of the CvMat-based train() above.
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvGBTreesParams params=CvGBTreesParams(),
                       bool update=false );

    // cv::Mat counterpart of predict(); slice selects the part of the
    // ensemble to use, k the ensemble index (see the CvMat version above).
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
                           const cv::Range& slice = cv::Range::all(),
                           int k=-1 ) const;
1595
1596protected:
1597
1598    /*
1599    // Compute the gradient vector components.
1600    //
1601    // API
1602    // virtual void find_gradient( const int k = 0);
1603
1604    // INPUT
1605    // k        - used for classification problem, determining current
1606    //            tree ensemble.
1607    // OUTPUT
1608    // changes components of data->responses
1609    // which correspond to samples used for training
1610    // on the current step.
1611    // RESULT
1612    */
1613    virtual void find_gradient( const int k = 0);
1614
1615
1616    /*
1617    //
1618    // Change values in tree leaves according to the used loss function.
1619    //
1620    // API
1621    // virtual void change_values(CvDTree* tree, const int k = 0);
1622    //
1623    // INPUT
1624    // tree      - decision tree to change.
1625    // k         - used for classification problem, determining current
1626    //             tree ensemble.
1627    // OUTPUT
1628    // changes 'value' fields of the trees' leaves.
1629    // changes sum_response_tmp.
1630    // RESULT
1631    */
1632    virtual void change_values(CvDTree* tree, const int k = 0);
1633
1634
1635    /*
1636    //
1637    // Find optimal constant prediction value according to the used loss
1638    // function.
1639    // The goal is to find a constant which gives the minimal summary loss
1640    // on the _Idx samples.
1641    //
1642    // API
1643    // virtual float find_optimal_value( const CvMat* _Idx );
1644    //
1645    // INPUT
1646    // _Idx        - indices of the samples from the training set.
1647    // OUTPUT
1648    // RESULT
1649    // optimal constant value.
1650    */
1651    virtual float find_optimal_value( const CvMat* _Idx );
1652
1653
1654    /*
1655    //
1656    // Randomly split the whole training set in two parts according
1657    // to params.portion.
1658    //
1659    // API
1660    // virtual void do_subsample();
1661    //
1662    // INPUT
1663    // OUTPUT
1664    // subsample_train - indices of samples used for training
1665    // subsample_test  - indices of samples used for test
1666    // RESULT
1667    */
1668    virtual void do_subsample();
1669
1670
1671    /*
1672    //
1673    // Internal recursive function giving an array of subtree tree leaves.
1674    //
1675    // API
1676    // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1677    //
1678    // INPUT
1679    // node         - current leaf.
1680    // OUTPUT
1681    // count        - count of leaves in the subtree.
1682    // leaves       - array of pointers to leaves.
1683    // RESULT
1684    */
1685    void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1686
1687
1688    /*
1689    //
1690    // Get leaves of the tree.
1691    //
1692    // API
1693    // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1694    //
1695    // INPUT
1696    // dtree            - decision tree.
1697    // OUTPUT
1698    // len              - count of the leaves.
1699    // RESULT
1700    // CvDTreeNode**    - array of pointers to leaves.
1701    */
1702    CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1703
1704
1705    /*
1706    //
1707    // Is it a regression or a classification.
1708    //
1709    // API
1710    // bool problem_type();
1711    //
1712    // INPUT
1713    // OUTPUT
1714    // RESULT
1715    // false if it is a classification problem,
1716    // true - if regression.
1717    */
1718    virtual bool problem_type() const;
1719
1720
1721    /*
1722    //
1723    // Write parameters of the gtb model.
1724    //
1725    // API
1726    // virtual void write_params( CvFileStorage* fs ) const;
1727    //
1728    // INPUT
1729    // fs           - file storage to write parameters to.
1730    // OUTPUT
1731    // RESULT
1732    */
1733    virtual void write_params( CvFileStorage* fs ) const;
1734
1735
1736    /*
1737    //
1738    // Read parameters of the gtb model and data.
1739    //
1740    // API
1741    // virtual void read_params( CvFileStorage* fs );
1742    //
1743    // INPUT
1744    // fs           - file storage to read parameters from.
1745    // OUTPUT
1746    // params       - parameters of the gtb model.
1747    // data         - contains information about the structure
1748    //                of the data set (count of variables,
1749    //                their types, etc.).
1750    // class_labels - output class labels map.
1751    // RESULT
1752    */
1753    virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
1754    int get_len(const CvMat* mat) const;
1755
1756
    CvDTreeTrainData* data;      // training data layout/structure
    CvGBTreesParams params;      // training parameters

    CvSeq** weak;                // sequences of weak trees, one per ensemble
    CvMat* orig_response;        // original training responses
    CvMat* sum_response;         // current cumulative ensemble response
    CvMat* sum_response_tmp;     // scratch copy updated by change_values()
    CvMat* sample_idx;           // indices of the training samples
    CvMat* subsample_train;      // train-part indices (see do_subsample())
    CvMat* subsample_test;       // test-part indices (see do_subsample())
    CvMat* missing;              // missing-values mask
    CvMat* class_labels;         // class labels map (classification case)

    cv::RNG* rng;                // random generator, presumably for subsampling

    int class_count;             // number of output classes
    float delta;                 // loss-function parameter; reset to 0 by clear()
    float base_value;            // initial constant prediction value

1775
1776};
1777
1778
1779
1780/****************************************************************************************\
1781*                              Artificial Neural Networks (ANN)                          *
1782\****************************************************************************************/
1783
1784/////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1785
// Training parameters for CvANN_MLP.
struct CvANN_MLP_TrainParams
{
    CvANN_MLP_TrainParams();
    // term_crit    - termination criteria of the training procedure.
    // train_method - BACKPROP or RPROP.
    // param1/param2 - method-dependent parameters; presumably mapped onto
    //                 the bp_*/rp_* fields below — confirm in the
    //                 implementation.
    CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
                           double param1, double param2=0 );
    ~CvANN_MLP_TrainParams();

    // available training algorithms
    enum { BACKPROP=0, RPROP=1 };

    CV_PROP_RW CvTermCriteria term_crit;  // training termination criteria
    CV_PROP_RW int train_method;          // BACKPROP or RPROP

    // backpropagation parameters
    CV_PROP_RW double bp_dw_scale, bp_moment_scale;

    // rprop parameters
    CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
};
1804
1805
// Multi-layer perceptron (feed-forward artificial neural network) with
// BACKPROP/RPROP training (see CvANN_MLP_TrainParams).
class CvANN_MLP : public CvStatModel
{
public:
    CV_WRAP CvANN_MLP();
    // Constructs the network immediately; presumably equivalent to the
    // default constructor followed by create().
    CvANN_MLP( const CvMat* layerSizes,
               int activateFunc=CvANN_MLP::SIGMOID_SYM,
               double fparam1=0, double fparam2=0 );

    virtual ~CvANN_MLP();

    // Builds the network topology. layerSizes holds the neuron count of
    // each layer (a row vector — get_layer_count() reads ->cols);
    // activateFunc is one of the activation-function enum values below,
    // fparam1/fparam2 its free parameters.
    virtual void create( const CvMat* layerSizes,
                         int activateFunc=CvANN_MLP::SIGMOID_SYM,
                         double fparam1=0, double fparam2=0 );

    // Trains the network; flags is a combination of the training flags
    // enum below. The int result is presumably the number of iterations
    // performed — confirm in the implementation.
    virtual int train( const CvMat* inputs, const CvMat* outputs,
                       const CvMat* sampleWeights, const CvMat* sampleIdx=0,
                       CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
                       int flags=0 );
    // Computes network responses for the given inputs into outputs.
    virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;

    // cv::Mat counterparts of the CvMat-based interface above.
    CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
              int activateFunc=CvANN_MLP::SIGMOID_SYM,
              double fparam1=0, double fparam2=0 );

    CV_WRAP virtual void create( const cv::Mat& layerSizes,
                        int activateFunc=CvANN_MLP::SIGMOID_SYM,
                        double fparam1=0, double fparam2=0 );

    CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
                      const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
                      CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
                      int flags=0 );

    CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;

    CV_WRAP virtual void clear();

    // possible activation functions
    enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };

    // available training flags
    enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* storage, const char* name ) const;

    int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
    const CvMat* get_layer_sizes() { return layer_sizes; }
    // Returns the weight vector of the given layer, or 0 when out of range.
    // NOTE(review): the bound is '<=' layer_sizes->cols, i.e. one past the
    // last layer index — presumably the weights array carries extra scale
    // entries; confirm in the implementation before changing this to '<'.
    double* get_weights(int layer)
    {
        return layer_sizes && weights &&
            (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
    }

    virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;

protected:

    // Validates/converts the training inputs before train_backprop/train_rprop.
    virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
            const CvMat* _sample_weights, const CvMat* sampleIdx,
            CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );

    // sequential random backpropagation
    virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );

    // RPROP algorithm
    virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );

    virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
    virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
                                 double _f_param1=0, double _f_param2=0 );
    virtual void init_weights();
    virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
    virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
    virtual void calc_input_scale( const CvVectors* vecs, int flags );
    virtual void calc_output_scale( const CvVectors* vecs, int flags );

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvMat* layer_sizes;          // row vector: neuron count per layer
    CvMat* wbuf;                 // buffer holding the weights
    CvMat* sample_weights;       // per-sample training weights
    double** weights;            // per-layer weight pointers (see get_weights)
    double f_param1, f_param2;   // activation function parameters
    double min_val, max_val, min_val1, max_val1;  // input/output scaling bounds — confirm
    int activ_func;              // IDENTITY, SIGMOID_SYM or GAUSSIAN
    int max_count, max_buf_sz;
    CvANN_MLP_TrainParams params;
    cv::RNG* rng;
};
1897
1898/****************************************************************************************\
*                           Auxiliary functions declarations                             *
1900\****************************************************************************************/
1901
/* Generates <sample> from a multivariate normal distribution, where <mean> is an
   average row vector and <cov> a symmetric covariance matrix */
CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
                           CvRNG* rng CV_DEFAULT(0) );

/* Generates a sample from a Gaussian mixture distribution with <clsnum>
   components described by <means>, <covs> and <weights>; the optional
   <sampClasses> presumably receives the component index of each sample */
CVAPI(void) cvRandGaussMixture( CvMat* means[],
                               CvMat* covs[],
                               float weights[],
                               int clsnum,
                               CvMat* sample,
                               CvMat* sampClasses CV_DEFAULT(0) );

/* test set type for cvCreateTestSet (concentric spheres) */
#define CV_TS_CONCENTRIC_SPHERES 0

/* creates test set */
CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
                 int num_samples,
                 int num_features,
                 CvMat** responses,
                 int num_classes, ... );
1923
1924/****************************************************************************************\
1925*                                      Data                                             *
1926\****************************************************************************************/
1927
// Modes for CvTrainTestSplit::train_sample_part_mode: the train part of
// the split is given either as an absolute sample count or as a portion.
#define CV_COUNT     0
#define CV_PORTION   1
1930
// Describes how CvMLData splits its samples into train and test subsets
// (see CvMLData::set_train_test_split).
struct CvTrainTestSplit
{
    CvTrainTestSplit();
    // train part given as an absolute number of samples (CV_COUNT mode)
    CvTrainTestSplit( int train_sample_count, bool mix = true);
    // train part given as a portion of all samples (CV_PORTION mode)
    CvTrainTestSplit( float train_sample_portion, bool mix = true);

    union
    {
        int count;      // used when train_sample_part_mode == CV_COUNT
        float portion;  // used when train_sample_part_mode == CV_PORTION
    } train_sample_part;
    int train_sample_part_mode;  // CV_COUNT or CV_PORTION

    bool mix;  // mix (shuffle) the samples; see CvMLData::mix_train_and_test_idx
};
1946
// Loads a data set from a CSV file and manages it for the ML training
// routines: values, responses, missing-value mask, variable types and the
// train/test split.
class CvMLData
{
public:
    CvMLData();
    virtual ~CvMLData();

    // returns:
    // 0 - OK
    // -1 - file can not be opened or is not correct
    int read_csv( const char* filename );

    const CvMat* get_values() const;   // all loaded values
    const CvMat* get_responses();      // response column (see set_response_idx)
    const CvMat* get_missing() const;  // missing-values mask

    void set_header_lines_number( int n );
    int get_header_lines_number() const;

    void set_response_idx( int idx ); // the old response becomes a predictor, new response_idx = idx;
                                      // if idx < 0 there will be no response
    int get_response_idx() const;

    void set_train_test_split( const CvTrainTestSplit * spl );
    const CvMat* get_train_sample_idx() const;
    const CvMat* get_test_sample_idx() const;
    void mix_train_and_test_idx();

    const CvMat* get_var_idx();
    void chahge_var_idx( int vi, bool state ); // misspelled (kept for backward compatibility),
                                               // use change_var_idx
    void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor

    const CvMat* get_var_types();
    int get_var_type( int var_idx ) const;
    // the following 2 methods enable changing the type of variables;
    // use them to assign the CV_VAR_CATEGORICAL type to a categorical variable
    // with numerical labels; in the other cases var types are correctly determined automatically
    void set_var_types( const char* str );  // str examples:
                                            // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
                                            // "cat", "ord" (all vars are categorical/ordered)
    void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }

    void set_delimiter( char ch );  // CSV field delimiter
    char get_delimiter() const;

    void set_miss_ch( char ch );    // character denoting a missing value
    char get_miss_ch() const;

    const std::map<cv::String, int>& get_class_labels_map() const;

protected:
    virtual void clear();

    // Parses one CSV token into a float element; type presumably reports
    // the detected variable type — confirm in the implementation.
    void str_to_flt_elem( const char* token, float& flt_elem, int& type);
    void free_train_test_idx();

    char delimiter;
    char miss_ch;
    //char flt_separator;

    CvMat* values;        // loaded data values
    CvMat* missing;       // missing-values mask
    CvMat* var_types;     // per-variable types
    CvMat* var_idx_mask;  // active-variable mask

    CvMat* response_out; // header
    CvMat* var_idx_out; // mat
    CvMat* var_types_out; // mat

    int header_lines_number;

    int response_idx;

    int train_sample_count;
    bool mix;

    int total_class_count;
    std::map<cv::String, int> class_map;  // class label -> index

    CvMat* train_sample_idx;
    CvMat* test_sample_idx;
    int* sample_idx; // data of train_sample_idx and test_sample_idx

    cv::RNG* rng;
};
2032
2033
namespace cv
{

// Aliases exposing the legacy C-style ML classes under the cv namespace.
typedef CvStatModel StatModel;
typedef CvParamGrid ParamGrid;
typedef CvNormalBayesClassifier NormalBayesClassifier;
typedef CvKNearest KNearest;
typedef CvSVMParams SVMParams;
typedef CvSVMKernel SVMKernel;
typedef CvSVMSolver SVMSolver;
typedef CvSVM SVM;
typedef CvDTreeParams DTreeParams;
typedef CvMLData TrainData;
typedef CvDTree DecisionTree;
typedef CvForestTree ForestTree;
typedef CvRTParams RandomTreeParams;
typedef CvRTrees RandomTrees;
typedef CvERTreeTrainData ERTreeTRainData; // (sic: capitalization kept for source compatibility)
typedef CvForestERTree ERTree;
typedef CvERTrees ERTrees;
typedef CvBoostParams BoostParams;
typedef CvBoostTree BoostTree;
typedef CvBoost Boost;
typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
typedef CvANN_MLP NeuralNet_MLP;
typedef CvGBTreesParams GradientBoostingTreeParams;
typedef CvGBTrees GradientBoostingTrees;

// Specialized deleter so cv::Ptr can own CvDTreeSplit objects.
template<> void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const;
}
2064
2065#endif // __cplusplus
2066#endif // __OPENCV_ML_HPP__
2067
2068/* End of file. */
2069