#ifndef _OPENCV_BOOST_H_
#define _OPENCV_BOOST_H_

#include "traincascade_features.h"
#include "old_ml.hpp"

7793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerstruct CvCascadeBoostParams : CvBoostParams
8793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
9793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    float minHitRate;
10793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    float maxFalseAlarm;
11793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
12793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CvCascadeBoostParams();
13793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CvCascadeBoostParams( int _boostType, float _minHitRate, float _maxFalseAlarm,
14793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                          double _weightTrimRate, int _maxDepth, int _maxWeakCount );
15793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual ~CvCascadeBoostParams() {}
16793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void write( cv::FileStorage &fs ) const;
17793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    bool read( const cv::FileNode &node );
18793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual void printDefaults() const;
19793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual void printAttrs() const;
20793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual bool scanAttr( const std::string prmName, const std::string val);
21793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler};
23793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerstruct CvCascadeBoostTrainData : CvDTreeTrainData
24793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
25793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
26793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                             const CvDTreeParams& _params );
27793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
28793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                             int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
29793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                             const CvDTreeParams& _params = CvDTreeParams() );
30793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual void setData( const CvFeatureEvaluator* _featureEvaluator,
31793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                          int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
32793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                          const CvDTreeParams& _params=CvDTreeParams() );
33793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void precalculate();
34793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
35793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
36793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
37793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf );
38793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf);
39793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf );
40793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
41793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf,
42793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                  const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf );
43793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf );
44793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual float getVarValue( int vi, int si );
45793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual void free_train_data();
46793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
47793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    const CvFeatureEvaluator* featureEvaluator;
48793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    cv::Mat valCache; // precalculated feature values (CV_32FC1)
49793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CvMat _resp; // for casting
50793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    int numPrecalcVal, numPrecalcIdx;
51793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler};
53793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerclass CvCascadeBoostTree : public CvBoostTree
54793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
55793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerpublic:
56793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual CvDTreeNode* predict( int sampleIdx ) const;
57793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void write( cv::FileStorage &fs, const cv::Mat& featureMap );
58793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void read( const cv::FileNode &node, CvBoost* _ensemble, CvDTreeTrainData* _data );
59793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void markFeaturesInMap( cv::Mat& featureMap );
60793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerprotected:
61793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual void split_node_data( CvDTreeNode* n );
62793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler};
64793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerclass CvCascadeBoost : public CvBoost
65793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
66793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerpublic:
67793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual bool train( const CvFeatureEvaluator* _featureEvaluator,
68793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
69793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
70793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual float predict( int sampleIdx, bool returnSum = false ) const;
71793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
72793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    float getThreshold() const { return threshold; }
73793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void write( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
74793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    bool read( const cv::FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
75793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler               const CvCascadeBoostParams& _params );
76793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void markUsedFeaturesInMap( cv::Mat& featureMap );
77793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerprotected:
78793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual bool set_params( const CvBoostParams& _params );
79793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual void update_weights( CvBoostTree* tree );
80793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual bool isErrDesired();
81793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
82793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    float threshold;
83793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    float minHitRate, maxFalseAlarm;
84793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler};

#endif