/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#ifndef __OPENCV_ML_PRECOMP_HPP__
#define __OPENCV_ML_PRECOMP_HPP__

#include "opencv2/core.hpp"
#include "opencv2/ml.hpp"
#include "opencv2/core/core_c.h"
#include "opencv2/core/utility.hpp"

#include "opencv2/core/private.hpp"

#include <assert.h>
#include <float.h>
#include <limits.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <algorithm> // std::min, std::max used below
#include <vector>

/****************************************************************************************\
 *                               Main struct definitions                                  *
 \****************************************************************************************/

/* log(2*PI) */
#define CV_LOG2PI (1.8378770664093454835606594728112)

namespace cv
{
namespace ml
{
    using std::vector;

    #define CV_DTREE_CAT_DIR(idx,subset) \
        (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
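
    // Illustrative note (not part of the original header; the left/right meaning of the
    // sign is an assumption, see calcDir() for the authoritative use): CV_DTREE_CAT_DIR
    // tests bit `idx` of the 32-bit words in `subset` and maps it to a branch direction,
    // -1 when the bit is set and +1 when it is clear. A minimal sketch:
    //
    //     int subset[2] = { 0, 0 };
    //     subset[40 >> 5] |= (1 << (40 & 31));          // put category 40 into the subset
    //     int d1 = CV_DTREE_CAT_DIR(40, subset);        // -1, category is in the subset
    //     int d2 = CV_DTREE_CAT_DIR(3, subset);         // +1, category is not
    //
    // The subset arrays are sized as one int per 32 categories, which matches
    // DTreesImpl::getSubsetSize() further below.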

    template<typename _Tp> struct cmp_lt_idx
    {
        cmp_lt_idx(const _Tp* _arr) : arr(_arr) {}
        bool operator ()(int a, int b) const { return arr[a] < arr[b]; }
        const _Tp* arr;
    };

    template<typename _Tp> struct cmp_lt_ptr
    {
        cmp_lt_ptr() {}
        bool operator ()(const _Tp* a, const _Tp* b) const { return *a < *b; }
    };
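
    // Usage sketch (illustrative, not from the original header): cmp_lt_idx orders an
    // index array by the values it refers to, giving a sorted view of a column without
    // moving the underlying data; cmp_lt_ptr does the same for arrays of pointers.
    // `values` and `order` below are hypothetical names.
    //
    //     float values[] = { 3.f, 1.f, 2.f };
    //     std::vector<int> order;
    //     setRangeVector(order, 3);                               // order = {0, 1, 2}
    //     std::sort(order.begin(), order.end(), cmp_lt_idx<float>(values));
    //     // order is now {1, 2, 0}: indices of values in ascending order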

    static inline void setRangeVector(std::vector<int>& vec, int n)
    {
        vec.resize(n);
        for( int i = 0; i < n; i++ )
            vec[i] = i;
    }

    static inline void writeTermCrit(FileStorage& fs, const TermCriteria& termCrit)
    {
        if( (termCrit.type & TermCriteria::EPS) != 0 )
            fs << "epsilon" << termCrit.epsilon;
        if( (termCrit.type & TermCriteria::COUNT) != 0 )
            fs << "iterations" << termCrit.maxCount;
    }

    static inline TermCriteria readTermCrit(const FileNode& fn)
    {
        TermCriteria termCrit;
        double epsilon = (double)fn["epsilon"];
        if( epsilon > 0 )
        {
            termCrit.type |= TermCriteria::EPS;
            termCrit.epsilon = epsilon;
        }
        int iters = (int)fn["iterations"];
        if( iters > 0 )
        {
            termCrit.type |= TermCriteria::COUNT;
            termCrit.maxCount = iters;
        }
        return termCrit;
    }
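
    // Round-trip sketch (illustrative only; "params.yml" and "term_criteria" are
    // hypothetical names): writeTermCrit stores the optional "epsilon"/"iterations"
    // entries inside the currently open mapping, and readTermCrit reconstructs the
    // TermCriteria flags from whichever entries are present.
    //
    //     FileStorage fs("params.yml", FileStorage::WRITE);
    //     fs << "term_criteria" << "{";
    //     writeTermCrit(fs, TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100, 1e-6));
    //     fs << "}";
    //
    //     FileStorage fs2("params.yml", FileStorage::READ);
    //     TermCriteria tc = readTermCrit(fs2["term_criteria"]);   // tc.maxCount == 100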

    struct TreeParams
    {
        TreeParams();
        TreeParams( int maxDepth, int minSampleCount,
                    double regressionAccuracy, bool useSurrogates,
                    int maxCategories, int CVFolds,
                    bool use1SERule, bool truncatePrunedTree,
                    const Mat& priors );

        inline void setMaxCategories(int val)
        {
            if( val < 2 )
                CV_Error( CV_StsOutOfRange, "max_categories should be >= 2" );
            maxCategories = std::min(val, 15);
        }
        inline void setMaxDepth(int val)
        {
            if( val < 0 )
                CV_Error( CV_StsOutOfRange, "max_depth should be >= 0" );
            maxDepth = std::min( val, 25 );
        }
        inline void setMinSampleCount(int val)
        {
            minSampleCount = std::max(val, 1);
        }
        inline void setCVFolds(int val)
        {
            if( val < 0 )
                CV_Error( CV_StsOutOfRange,
                          "params.CVFolds should be =0 (the tree is not pruned) "
                          "or n>0 (tree is pruned using n-fold cross-validation)" );
            if( val == 1 )
                val = 0;
            CVFolds = val;
        }
        inline void setRegressionAccuracy(float val)
        {
            if( val < 0 )
                CV_Error( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
            regressionAccuracy = val;
        }

        inline int getMaxCategories() const { return maxCategories; }
        inline int getMaxDepth() const { return maxDepth; }
        inline int getMinSampleCount() const { return minSampleCount; }
        inline int getCVFolds() const { return CVFolds; }
        inline float getRegressionAccuracy() const { return regressionAccuracy; }

        CV_IMPL_PROPERTY(bool, UseSurrogates, useSurrogates)
        CV_IMPL_PROPERTY(bool, Use1SERule, use1SERule)
        CV_IMPL_PROPERTY(bool, TruncatePrunedTree, truncatePrunedTree)
        CV_IMPL_PROPERTY_S(cv::Mat, Priors, priors)

    public:
        bool  useSurrogates;
        bool  use1SERule;
        bool  truncatePrunedTree;
        Mat   priors;

    protected:
        int   maxCategories;
        int   maxDepth;
        int   minSampleCount;
        int   CVFolds;
        float regressionAccuracy;
    };
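
    // Behaviour note, inferred from the setters above (not original documentation):
    // setMaxCategories() clamps to at most 15 categories, setMaxDepth() clamps to a
    // depth of 25, and setCVFolds() treats 1 the same as 0, i.e. no cross-validation
    // pruning. For example:
    //
    //     TreeParams p;
    //     p.setMaxDepth(100);   // stored as 25
    //     p.setCVFolds(1);      // stored as 0 (pruning disabled)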

    struct RTreeParams
    {
        RTreeParams();
        RTreeParams(bool calcVarImportance, int nactiveVars, TermCriteria termCrit );
        bool calcVarImportance;
        int nactiveVars;
        TermCriteria termCrit;
    };

    struct BoostTreeParams
    {
        BoostTreeParams();
        BoostTreeParams(int boostType, int weakCount, double weightTrimRate);
        int boostType;
        int weakCount;
        double weightTrimRate;
    };

    class DTreesImpl : public DTrees
    {
    public:
        struct WNode
        {
            WNode()
            {
                class_idx = sample_count = depth = complexity = 0;
                parent = left = right = split = defaultDir = -1;
                Tn = INT_MAX;
                value = maxlr = alpha = node_risk = tree_risk = tree_error = 0.;
            }

            int class_idx;
            double Tn;
            double value;

            int parent;
            int left;
            int right;
            int defaultDir;

            int split;

            int sample_count;
            int depth;
            double maxlr;

            // global pruning data
            int complexity;
            double alpha;
            double node_risk, tree_risk, tree_error;
        };

        struct WSplit
        {
            WSplit()
            {
                varIdx = next = 0;
                inversed = false;
                quality = c = 0.f;
                subsetOfs = -1;
            }

            int varIdx;
            bool inversed;
            float quality;
            int next;
            float c;
            int subsetOfs;
        };
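
        // Interpretation note (an assumption based on how these fields are used
        // elsewhere in the module, not part of the original header): for an ordered
        // variable the split compares the sample value against the threshold `c`;
        // for a categorical variable `subsetOfs` is an offset into DTreesImpl::subsets,
        // whose bits are tested with CV_DTREE_CAT_DIR, e.g. (hypothetical):
        //
        //     int dir = CV_DTREE_CAT_DIR(catValue, &subsets[split.subsetOfs]);
        //
        // `next` chains additional (surrogate) splits of the same node and
        // `inversed` flips the branch direction.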

        struct WorkData
        {
            WorkData(const Ptr<TrainData>& _data);

            Ptr<TrainData> data;
            vector<WNode> wnodes;
            vector<WSplit> wsplits;
            vector<int> wsubsets;
            vector<double> cv_Tn;
            vector<double> cv_node_risk;
            vector<double> cv_node_error;
            vector<int> cv_labels;
            vector<double> sample_weights;
            vector<int> cat_responses;
            vector<double> ord_responses;
            vector<int> sidx;
            int maxSubsetSize;
        };

        CV_WRAP_SAME_PROPERTY(int, MaxCategories, params)
        CV_WRAP_SAME_PROPERTY(int, MaxDepth, params)
        CV_WRAP_SAME_PROPERTY(int, MinSampleCount, params)
        CV_WRAP_SAME_PROPERTY(int, CVFolds, params)
        CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, params)
        CV_WRAP_SAME_PROPERTY(bool, Use1SERule, params)
        CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, params)
        CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, params)
        CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, params)

        DTreesImpl();
        virtual ~DTreesImpl();
        virtual void clear();

        String getDefaultName() const { return "opencv_ml_dtree"; }
        bool isTrained() const { return !roots.empty(); }
        bool isClassifier() const { return _isClassifier; }
        int getVarCount() const { return varType.empty() ? 0 : (int)(varType.size() - 1); }
        int getCatCount(int vi) const { return catOfs[vi][1] - catOfs[vi][0]; }
        int getSubsetSize(int vi) const { return (getCatCount(vi) + 31)/32; }

        virtual void setDParams(const TreeParams& _params);
        virtual void startTraining( const Ptr<TrainData>& trainData, int flags );
        virtual void endTraining();
        virtual void initCompVarIdx();
        virtual bool train( const Ptr<TrainData>& trainData, int flags );

        virtual int addTree( const vector<int>& sidx );
        virtual int addNodeAndTrySplit( int parent, const vector<int>& sidx );
        virtual const vector<int>& getActiveVars();
        virtual int findBestSplit( const vector<int>& _sidx );
        virtual void calcValue( int nidx, const vector<int>& _sidx );

        virtual WSplit findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality );

        // simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
        virtual void clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels );
        virtual WSplit findSplitCatClass( int vi, const vector<int>& _sidx, double initQuality, int* subset );

        virtual WSplit findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality );
        virtual WSplit findSplitCatReg( int vi, const vector<int>& _sidx, double initQuality, int* subset );

        virtual int calcDir( int splitidx, const vector<int>& _sidx, vector<int>& _sleft, vector<int>& _sright );
        virtual int pruneCV( int root );

        virtual double updateTreeRNC( int root, double T, int fold );
        virtual bool cutTree( int root, double T, int fold, double min_alpha );
        virtual float predictTrees( const Range& range, const Mat& sample, int flags ) const;
        virtual float predict( InputArray inputs, OutputArray outputs, int flags ) const;

        virtual void writeTrainingParams( FileStorage& fs ) const;
        virtual void writeParams( FileStorage& fs ) const;
        virtual void writeSplit( FileStorage& fs, int splitidx ) const;
        virtual void writeNode( FileStorage& fs, int nidx, int depth ) const;
        virtual void writeTree( FileStorage& fs, int root ) const;
        virtual void write( FileStorage& fs ) const;

        virtual void readParams( const FileNode& fn );
        virtual int readSplit( const FileNode& fn );
        virtual int readNode( const FileNode& fn );
        virtual int readTree( const FileNode& fn );
        virtual void read( const FileNode& fn );

        virtual const std::vector<int>& getRoots() const { return roots; }
        virtual const std::vector<Node>& getNodes() const { return nodes; }
        virtual const std::vector<Split>& getSplits() const { return splits; }
        virtual const std::vector<int>& getSubsets() const { return subsets; }

        TreeParams params;

        vector<int> varIdx;
        vector<int> compVarIdx;
        vector<uchar> varType;
        vector<Vec2i> catOfs;
        vector<int> catMap;
        vector<int> roots;
        vector<Node> nodes;
        vector<Split> splits;
        vector<int> subsets;
        vector<int> classLabels;
        vector<float> missingSubst;
        vector<int> varMapping;
        bool _isClassifier;

        Ptr<WorkData> w;
    };

    template <typename T>
    static inline void readVectorOrMat(const FileNode & node, std::vector<T> & v)
    {
        if (node.type() == FileNode::MAP)
        {
            Mat m;
            node >> m;
            m.copyTo(v);
        }
        else if (node.type() == FileNode::SEQ)
        {
            node >> v;
        }
    }
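
    // Usage sketch (illustrative; `fn` and the "class_weights" key are hypothetical):
    // readVectorOrMat accepts either storage layout for the same parameter, a
    // serialized Mat (MAP node) or a plain sequence (SEQ node), and fills the vector
    // in both cases.
    //
    //     std::vector<double> weights;
    //     readVectorOrMat(fn["class_weights"], weights);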

}}

#endif /* __OPENCV_ML_PRECOMP_HPP__ */
