1793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler/*M///////////////////////////////////////////////////////////////////////////////////////
2793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
3793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
5793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//  By downloading, copying, installing or using the software you agree to this license.
6793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//  If you do not agree to this license, do not download, install,
7793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//  copy or use the software.
8793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
9793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
10793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//                           License Agreement
11793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//                For Open Source Computer Vision Library
12793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
13793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Copyright (C) 2000, Intel Corporation, all rights reserved.
14793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Copyright (C) 2014, Itseez Inc, all rights reserved.
15793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Third party copyrights are property of their respective owners.
16793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
17793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Redistribution and use in source and binary forms, with or without modification,
18793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// are permitted provided that the following conditions are met:
19793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
20793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//   * Redistribution's of source code must retain the above copyright notice,
21793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//     this list of conditions and the following disclaimer.
22793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
23793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//   * Redistribution's in binary form must reproduce the above copyright notice,
24793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//     this list of conditions and the following disclaimer in the documentation
25793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//     and/or other materials provided with the distribution.
26793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
27793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//   * The name of the copyright holders may not be used to endorse or promote products
28793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//     derived from this software without specific prior written permission.
29793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
30793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// This software is provided by the copyright holders and contributors "as is" and
31793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// any express or implied warranties, including, but not limited to, the implied
32793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// warranties of merchantability and fitness for a particular purpose are disclaimed.
33793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// In no event shall the Intel Corporation or contributors be liable for any direct,
34793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// indirect, incidental, special, exemplary, or consequential damages
35793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// (including, but not limited to, procurement of substitute goods or services;
36793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// loss of use, data, or profits; or business interruption) however caused
37793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// and on any theory of liability, whether in contract, strict liability,
38793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// or tort (including negligence or otherwise) arising in any way out of
39793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// the use of this software, even if advised of the possibility of such damage.
40793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//
41793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//M*/
42793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
43793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler#include "precomp.hpp"
44793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
45793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslernamespace cv {
46793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslernamespace ml {
47793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
48793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//////////////////////////////////////////////////////////////////////////////////////////
49793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//                                  Random trees                                        //
50793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//////////////////////////////////////////////////////////////////////////////////////////
51793ee12c6df9cad3806238d32528c49a3ff9331dNoah PreslerRTreeParams::RTreeParams()
52793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
53793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    calcVarImportance = false;
54793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    nactiveVars = 0;
55793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    termCrit = TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 50, 0.1);
56793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}
57793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
58793ee12c6df9cad3806238d32528c49a3ff9331dNoah PreslerRTreeParams::RTreeParams(bool _calcVarImportance,
59793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                         int _nactiveVars,
60793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                         TermCriteria _termCrit )
61793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
62793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    calcVarImportance = _calcVarImportance;
63793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    nactiveVars = _nactiveVars;
64793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    termCrit = _termCrit;
65793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}
66793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
67793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
68793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerclass DTreesImplForRTrees : public DTreesImpl
69793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
70793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerpublic:
71793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    DTreesImplForRTrees()
72793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
73793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.setMaxDepth(5);
74793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.setMinSampleCount(10);
75793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.setRegressionAccuracy(0.f);
76793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.useSurrogates = false;
77793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.setMaxCategories(10);
78793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.setCVFolds(0);
79793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.use1SERule = false;
80793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.truncatePrunedTree = false;
81793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        params.priors = Mat();
82793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
83793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual ~DTreesImplForRTrees() {}
84793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
85793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void clear()
86793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
87793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        DTreesImpl::clear();
88793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        oobError = 0.;
89793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        rng = RNG((uint64)-1);
90793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
91793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
92793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    const vector<int>& getActiveVars()
93793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
94793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, nvars = (int)allVars.size(), m = (int)activeVars.size();
95793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( i = 0; i < nvars; i++ )
96793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
97793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int i1 = rng.uniform(0, nvars);
98793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int i2 = rng.uniform(0, nvars);
99793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            std::swap(allVars[i1], allVars[i2]);
100793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
101793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( i = 0; i < m; i++ )
102793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            activeVars[i] = allVars[i];
103793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return activeVars;
104793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
105793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
106793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void startTraining( const Ptr<TrainData>& trainData, int flags )
107793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
108793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        DTreesImpl::startTraining(trainData, flags);
109793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int nvars = w->data->getNVars();
110793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
111793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        m = std::min(std::max(m, 1), nvars);
112793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        allVars.resize(nvars);
113793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        activeVars.resize(m);
114793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( i = 0; i < nvars; i++ )
115793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            allVars[i] = varIdx[i];
116793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
117793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
118793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void endTraining()
119793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
120793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        DTreesImpl::endTraining();
121793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<int> a, b;
122793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        std::swap(allVars, a);
123793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        std::swap(activeVars, b);
124793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
125793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
126793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    bool train( const Ptr<TrainData>& trainData, int flags )
127793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
128793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        startTraining(trainData, flags);
129793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
130793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            rparams.termCrit.maxCount : 10000;
131793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int i, j, k, vi, vi_, n = (int)w->sidx.size();
132793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int nclasses = (int)classLabels.size();
133793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double eps = (rparams.termCrit.type & TermCriteria::EPS) != 0 &&
134793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            rparams.termCrit.epsilon > 0 ? rparams.termCrit.epsilon : 0.;
135793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<int> sidx(n);
136793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<uchar> oobmask(n);
137793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<int> oobidx;
138793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<int> oobperm;
139793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<double> oobres(n, 0.);
140793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<int> oobcount(n, 0);
141793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<int> oobvotes(n*nclasses, 0);
142793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int nvars = w->data->getNVars();
143793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int nallvars = w->data->getNAllVars();
144793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        const int* vidx = !varIdx.empty() ? &varIdx[0] : 0;
145793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        vector<float> samplebuf(nallvars);
146793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        Mat samples = w->data->getSamples();
147793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        float* psamples = samples.ptr<float>();
148793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        size_t sstep0 = samples.step1(), sstep1 = 1;
149793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        Mat sample0, sample(nallvars, 1, CV_32F, &samplebuf[0]);
150793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int predictFlags = _isClassifier ? (PREDICT_MAX_VOTE + RAW_OUTPUT) : PREDICT_SUM;
151793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
152793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        bool calcOOBError = eps > 0 || rparams.calcVarImportance;
153793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        double max_response = 0.;
154793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
155793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( w->data->getLayout() == COL_SAMPLE )
156793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            std::swap(sstep0, sstep1);
157793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
158793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( !_isClassifier )
159793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
160793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
161793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
162793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double val = std::abs(w->ord_responses[w->sidx[i]]);
163793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                max_response = std::max(max_response, val);
164793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
165793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
166793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
167793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( rparams.calcVarImportance )
168793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            varImportance.resize(nallvars, 0.f);
169793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
170793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( treeidx = 0; treeidx < ntrees; treeidx++ )
171793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
172793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
173793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                oobmask[i] = (uchar)1;
174793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
175793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( i = 0; i < n; i++ )
176793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
177793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                j = rng.uniform(0, n);
178793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                sidx[i] = w->sidx[j];
179793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                oobmask[j] = (uchar)0;
180793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
181793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            int root = addTree( sidx );
182793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( root < 0 )
183793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                return false;
184793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
185793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( calcOOBError )
186793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            {
187793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                oobidx.clear();
188793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( i = 0; i < n; i++ )
189793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
190793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    if( !oobmask[i] )
191793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        oobidx.push_back(i);
192793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
193793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                int n_oob = (int)oobidx.size();
194793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                // if there is no out-of-bag samples, we can not compute OOB error
195793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                // nor update the variable importance vector; so we proceed to the next tree
196793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( n_oob == 0 )
197793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    continue;
198793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                double ncorrect_responses = 0.;
199793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
200793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                oobError = 0.;
201793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                for( i = 0; i < n_oob; i++ )
202793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
203793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    j = oobidx[i];
204793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
205793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
206793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
207793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    if( !_isClassifier )
208793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    {
209793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        oobres[j] += val;
210793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        oobcount[j]++;
211793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        double true_val = w->ord_responses[w->sidx[j]];
212793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        double a = oobres[j]/oobcount[j] - true_val;
213793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        oobError += a*a;
214793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        val = (val - true_val)/max_response;
215793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        ncorrect_responses += std::exp( -val*val );
216793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    }
217793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    else
218793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    {
219793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        int ival = cvRound(val);
220793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        int* votes = &oobvotes[j*nclasses];
221793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        votes[ival]++;
222793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        int best_class = 0;
223793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        for( k = 1; k < nclasses; k++ )
224793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            if( votes[best_class] < votes[k] )
225793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                best_class = k;
226793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        int diff = best_class != w->cat_responses[w->sidx[j]];
227793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        oobError += diff;
228793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        ncorrect_responses += diff == 0;
229793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    }
230793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
231793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
232793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                oobError /= n_oob;
233793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                if( rparams.calcVarImportance && n_oob > 1 )
234793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                {
235793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    oobperm.resize(n_oob);
236793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    for( i = 0; i < n_oob; i++ )
237793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        oobperm[i] = oobidx[i];
238793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
239793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    for( vi_ = 0; vi_ < nvars; vi_++ )
240793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    {
241793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        vi = vidx ? vidx[vi_] : vi_;
242793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        double ncorrect_responses_permuted = 0;
243793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        for( i = 0; i < n_oob; i++ )
244793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        {
245793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            int i1 = rng.uniform(0, n_oob);
246793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            int i2 = rng.uniform(0, n_oob);
247793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            std::swap(i1, i2);
248793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        }
249793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
250793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        for( i = 0; i < n_oob; i++ )
251793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        {
252793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            j = oobidx[i];
253793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            int vj = oobperm[i];
254793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
255793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            for( k = 0; k < nallvars; k++ )
256793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                sample.at<float>(k) = sample0.at<float>(k);
257793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            sample.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
258793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
259793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
260793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            if( !_isClassifier )
261793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            {
262793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                val = (val - w->ord_responses[w->sidx[j]])/max_response;
263793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                ncorrect_responses_permuted += exp( -val*val );
264793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            }
265793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                            else
266793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                                ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]];
267793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        }
268793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                        varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted);
269793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                    }
270793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                }
271793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            }
272793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            if( calcOOBError && oobError < eps )
273793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                break;
274793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
275793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
276793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( rparams.calcVarImportance )
277793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
278793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            for( vi_ = 0; vi_ < nallvars; vi_++ )
279793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler                varImportance[vi_] = std::max(varImportance[vi_], 0.f);
280793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            normalize(varImportance, varImportance, 1., 0, NORM_L1);
281793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
282793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        endTraining();
283793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return true;
284793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
285793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
286793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void writeTrainingParams( FileStorage& fs ) const
287793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
288793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        DTreesImpl::writeTrainingParams(fs);
289793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "nactive_vars" << rparams.nactiveVars;
290793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
291793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
292793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void write( FileStorage& fs ) const
293793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
294793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( roots.empty() )
295793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            CV_Error( CV_StsBadArg, "RTrees have not been trained" );
296793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
297793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        writeParams(fs);
298793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
299793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "oob_error" << oobError;
300793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        if( !varImportance.empty() )
301793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "var_importance" << varImportance;
302793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
303793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int k, ntrees = (int)roots.size();
304793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
305793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "ntrees" << ntrees
306793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler           << "trees" << "[";
307793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
308793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( k = 0; k < ntrees; k++ )
309793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
310793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "{";
311793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            writeTree(fs, roots[k]);
312793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            fs << "}";
313793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
314793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
315793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        fs << "]";
316793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
317793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
318793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void readParams( const FileNode& fn )
319793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
320793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        DTreesImpl::readParams(fn);
321793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
322793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        FileNode tparams_node = fn["training_params"];
323793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        rparams.nactiveVars = (int)tparams_node["nactive_vars"];
324793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
325793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
326793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void read( const FileNode& fn )
327793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
328793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        clear();
329793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
330793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        //int nclasses = (int)fn["nclasses"];
331793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        //int nsamples = (int)fn["nsamples"];
332793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        oobError = (double)fn["oob_error"];
333793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        int ntrees = (int)fn["ntrees"];
334793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
335793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        readVectorOrMat(fn["var_importance"], varImportance);
336793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
337793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        readParams(fn);
338793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
339793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        FileNode trees_node = fn["trees"];
340793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        FileNodeIterator it = trees_node.begin();
341793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        CV_Assert( ntrees == (int)trees_node.size() );
342793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
343793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
344793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        {
345793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            FileNode nfn = (*it)["nodes"];
346793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler            readTree(nfn);
347793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        }
348793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
349793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
350793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    RTreeParams rparams;
351793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    double oobError;
352793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    vector<float> varImportance;
353793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    vector<int> allVars, activeVars;
354793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    RNG rng;
355793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler};
356793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
357793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
358793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerclass RTreesImpl : public RTrees
359793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
360793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerpublic:
361793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_IMPL_PROPERTY(bool, CalculateVarImportance, impl.rparams.calcVarImportance)
362793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_IMPL_PROPERTY(int, ActiveVarCount, impl.rparams.nactiveVars)
363793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, impl.rparams.termCrit)
364793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
365793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_WRAP_SAME_PROPERTY(int, MaxCategories, impl.params)
366793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_WRAP_SAME_PROPERTY(int, MaxDepth, impl.params)
367793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_WRAP_SAME_PROPERTY(int, MinSampleCount, impl.params)
368793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_WRAP_SAME_PROPERTY(int, CVFolds, impl.params)
369793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, impl.params)
370793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_WRAP_SAME_PROPERTY(bool, Use1SERule, impl.params)
371793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, impl.params)
372793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, impl.params)
373793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, impl.params)
374793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
375793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    RTreesImpl() {}
376793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    virtual ~RTreesImpl() {}
377793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
378793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    String getDefaultName() const { return "opencv_ml_rtrees"; }
379793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
380793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    bool train( const Ptr<TrainData>& trainData, int flags )
381793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
382793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return impl.train(trainData, flags);
383793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
384793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
385793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    float predict( InputArray samples, OutputArray results, int flags ) const
386793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
387793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        return impl.predict(samples, results, flags);
388793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
389793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
390793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void write( FileStorage& fs ) const
391793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
392793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        impl.write(fs);
393793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
394793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
395793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    void read( const FileNode& fn )
396793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    {
397793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler        impl.read(fn);
398793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    }
399793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
400793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
401793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    int getVarCount() const { return impl.getVarCount(); }
402793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
403793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    bool isTrained() const { return impl.isTrained(); }
404793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    bool isClassifier() const { return impl.isClassifier(); }
405793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
406793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    const vector<int>& getRoots() const { return impl.getRoots(); }
407793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    const vector<Node>& getNodes() const { return impl.getNodes(); }
408793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    const vector<Split>& getSplits() const { return impl.getSplits(); }
409793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    const vector<int>& getSubsets() const { return impl.getSubsets(); }
410793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
411793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    DTreesImplForRTrees impl;
412793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler};
413793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
414793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
415793ee12c6df9cad3806238d32528c49a3ff9331dNoah PreslerPtr<RTrees> RTrees::create()
416793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{
417793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler    return makePtr<RTreesImpl>();
418793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}
419793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
420793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}}
421793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler
422793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// End of file.
423