1793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler/*M/////////////////////////////////////////////////////////////////////////////////////// 2793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 3793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 4793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 5793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// By downloading, copying, installing or using the software you agree to this license. 6793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// If you do not agree to this license, do not download, install, 7793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// copy or use the software. 8793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 9793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 10793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// License Agreement 11793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// For Open Source Computer Vision Library 12793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 13793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Copyright (C) 2000, Intel Corporation, all rights reserved. 14793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Copyright (C) 2014, Itseez Inc, all rights reserved. 15793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Third party copyrights are property of their respective owners. 16793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 17793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Redistribution and use in source and binary forms, with or without modification, 18793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// are permitted provided that the following conditions are met: 19793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 20793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// * Redistribution's of source code must retain the above copyright notice, 21793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// this list of conditions and the following disclaimer. 22793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 23793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// * Redistribution's in binary form must reproduce the above copyright notice, 24793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// this list of conditions and the following disclaimer in the documentation 25793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// and/or other materials provided with the distribution. 26793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 27793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// * The name of the copyright holders may not be used to endorse or promote products 28793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// derived from this software without specific prior written permission. 29793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 30793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// This software is provided by the copyright holders and contributors "as is" and 31793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// any express or implied warranties, including, but not limited to, the implied 32793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// warranties of merchantability and fitness for a particular purpose are disclaimed. 33793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// In no event shall the Intel Corporation or contributors be liable for any direct, 34793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// indirect, incidental, special, exemplary, or consequential damages 35793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// (including, but not limited to, procurement of substitute goods or services; 36793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// loss of use, data, or profits; or business interruption) however caused 37793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// and on any theory of liability, whether in contract, strict liability, 38793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// or tort (including negligence or otherwise) arising in any way out of 39793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// the use of this software, even if advised of the possibility of such damage. 40793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// 41793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler//M*/ 42793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 43793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler#include "precomp.hpp" 44793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 45793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslernamespace cv { 46793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslernamespace ml { 47793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 48793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler////////////////////////////////////////////////////////////////////////////////////////// 49793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// Random trees // 50793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler////////////////////////////////////////////////////////////////////////////////////////// 51793ee12c6df9cad3806238d32528c49a3ff9331dNoah PreslerRTreeParams::RTreeParams() 52793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{ 53793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler calcVarImportance = false; 54793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler nactiveVars = 0; 55793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler termCrit = TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 50, 0.1); 56793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler} 57793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 58793ee12c6df9cad3806238d32528c49a3ff9331dNoah PreslerRTreeParams::RTreeParams(bool _calcVarImportance, 59793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int _nactiveVars, 60793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler TermCriteria _termCrit ) 61793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{ 62793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler calcVarImportance = _calcVarImportance; 63793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler nactiveVars = _nactiveVars; 64793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler termCrit = _termCrit; 65793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler} 66793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 67793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 68793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerclass DTreesImplForRTrees : public DTreesImpl 69793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{ 70793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerpublic: 71793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler DTreesImplForRTrees() 72793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 73793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.setMaxDepth(5); 74793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.setMinSampleCount(10); 75793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.setRegressionAccuracy(0.f); 76793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.useSurrogates = false; 77793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.setMaxCategories(10); 78793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.setCVFolds(0); 79793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.use1SERule = false; 80793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.truncatePrunedTree = false; 81793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler params.priors = Mat(); 82793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 83793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler virtual ~DTreesImplForRTrees() {} 84793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 85793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void clear() 86793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 87793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler DTreesImpl::clear(); 88793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobError = 0.; 89793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler rng = RNG((uint64)-1); 90793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 91793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 92793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const vector<int>& getActiveVars() 93793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 94793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, nvars = (int)allVars.size(), m = (int)activeVars.size(); 95793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < nvars; i++ ) 96793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 97793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i1 = rng.uniform(0, nvars); 98793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i2 = rng.uniform(0, nvars); 99793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler std::swap(allVars[i1], allVars[i2]); 100793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 101793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < m; i++ ) 102793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler activeVars[i] = allVars[i]; 103793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return activeVars; 104793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 105793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 106793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void startTraining( const Ptr<TrainData>& trainData, int flags ) 107793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 108793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler DTreesImpl::startTraining(trainData, flags); 109793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int nvars = w->data->getNVars(); 110793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars)); 111793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler m = std::min(std::max(m, 1), nvars); 112793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler allVars.resize(nvars); 113793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler activeVars.resize(m); 114793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < nvars; i++ ) 115793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler allVars[i] = varIdx[i]; 116793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 117793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 118793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void endTraining() 119793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 120793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler DTreesImpl::endTraining(); 121793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<int> a, b; 122793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler std::swap(allVars, a); 123793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler std::swap(activeVars, b); 124793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 125793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 126793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool train( const Ptr<TrainData>& trainData, int flags ) 127793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 128793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler startTraining(trainData, flags); 129793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ? 130793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler rparams.termCrit.maxCount : 10000; 131793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i, j, k, vi, vi_, n = (int)w->sidx.size(); 132793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int nclasses = (int)classLabels.size(); 133793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double eps = (rparams.termCrit.type & TermCriteria::EPS) != 0 && 134793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler rparams.termCrit.epsilon > 0 ? rparams.termCrit.epsilon : 0.; 135793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<int> sidx(n); 136793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<uchar> oobmask(n); 137793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<int> oobidx; 138793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<int> oobperm; 139793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<double> oobres(n, 0.); 140793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<int> oobcount(n, 0); 141793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<int> oobvotes(n*nclasses, 0); 142793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int nvars = w->data->getNVars(); 143793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int nallvars = w->data->getNAllVars(); 144793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const int* vidx = !varIdx.empty() ? &varIdx[0] : 0; 145793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<float> samplebuf(nallvars); 146793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat samples = w->data->getSamples(); 147793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler float* psamples = samples.ptr<float>(); 148793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler size_t sstep0 = samples.step1(), sstep1 = 1; 149793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat sample0, sample(nallvars, 1, CV_32F, &samplebuf[0]); 150793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int predictFlags = _isClassifier ? (PREDICT_MAX_VOTE + RAW_OUTPUT) : PREDICT_SUM; 151793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 152793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool calcOOBError = eps > 0 || rparams.calcVarImportance; 153793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double max_response = 0.; 154793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 155793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( w->data->getLayout() == COL_SAMPLE ) 156793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler std::swap(sstep0, sstep1); 157793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 158793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !_isClassifier ) 159793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 160793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 161793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 162793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double val = std::abs(w->ord_responses[w->sidx[i]]); 163793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler max_response = std::max(max_response, val); 164793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 165793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 166793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 167793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( rparams.calcVarImportance ) 168793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler varImportance.resize(nallvars, 0.f); 169793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 170793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( treeidx = 0; treeidx < ntrees; treeidx++ ) 171793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 172793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 173793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobmask[i] = (uchar)1; 174793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 175793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 176793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 177793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler j = rng.uniform(0, n); 178793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler sidx[i] = w->sidx[j]; 179793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobmask[j] = (uchar)0; 180793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 181793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int root = addTree( sidx ); 182793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( root < 0 ) 183793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return false; 184793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 185793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( calcOOBError ) 186793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 187793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobidx.clear(); 188793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n; i++ ) 189793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 190793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !oobmask[i] ) 191793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobidx.push_back(i); 192793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 193793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int n_oob = (int)oobidx.size(); 194793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // if there is no out-of-bag samples, we can not compute OOB error 195793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler // nor update the variable importance vector; so we proceed to the next tree 196793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( n_oob == 0 ) 197793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler continue; 198793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double ncorrect_responses = 0.; 199793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 200793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobError = 0.; 201793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n_oob; i++ ) 202793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 203793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler j = oobidx[i]; 204793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) ); 205793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 206793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags); 207793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !_isClassifier ) 208793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 209793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobres[j] += val; 210793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobcount[j]++; 211793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double true_val = w->ord_responses[w->sidx[j]]; 212793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double a = oobres[j]/oobcount[j] - true_val; 213793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobError += a*a; 214793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler val = (val - true_val)/max_response; 215793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler ncorrect_responses += std::exp( -val*val ); 216793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 217793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 218793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 219793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int ival = cvRound(val); 220793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int* votes = &oobvotes[j*nclasses]; 221793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler votes[ival]++; 222793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int best_class = 0; 223793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( k = 1; k < nclasses; k++ ) 224793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( votes[best_class] < votes[k] ) 225793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler best_class = k; 226793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int diff = best_class != w->cat_responses[w->sidx[j]]; 227793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobError += diff; 228793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler ncorrect_responses += diff == 0; 229793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 230793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 231793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 232793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobError /= n_oob; 233793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( rparams.calcVarImportance && n_oob > 1 ) 234793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 235793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobperm.resize(n_oob); 236793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n_oob; i++ ) 237793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobperm[i] = oobidx[i]; 238793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 239793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( vi_ = 0; vi_ < nvars; vi_++ ) 240793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 241793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vi = vidx ? vidx[vi_] : vi_; 242793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double ncorrect_responses_permuted = 0; 243793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n_oob; i++ ) 244793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 245793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i1 = rng.uniform(0, n_oob); 246793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int i2 = rng.uniform(0, n_oob); 247793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler std::swap(i1, i2); 248793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 249793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 250793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( i = 0; i < n_oob; i++ ) 251793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 252793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler j = oobidx[i]; 253793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int vj = oobperm[i]; 254793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) ); 255793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( k = 0; k < nallvars; k++ ) 256793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler sample.at<float>(k) = sample0.at<float>(k); 257793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler sample.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi]; 258793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 259793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags); 260793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !_isClassifier ) 261793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 262793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler val = (val - w->ord_responses[w->sidx[j]])/max_response; 263793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler ncorrect_responses_permuted += exp( -val*val ); 264793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 265793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler else 266793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]]; 267793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 268793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted); 269793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 270793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 271793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 272793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( calcOOBError && oobError < eps ) 273793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler break; 274793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 275793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 276793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( rparams.calcVarImportance ) 277793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 278793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( vi_ = 0; vi_ < nallvars; vi_++ ) 279793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler varImportance[vi_] = std::max(varImportance[vi_], 0.f); 280793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler normalize(varImportance, varImportance, 1., 0, NORM_L1); 281793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 282793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler endTraining(); 283793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return true; 284793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 285793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 286793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void writeTrainingParams( FileStorage& fs ) const 287793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 288793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler DTreesImpl::writeTrainingParams(fs); 289793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "nactive_vars" << rparams.nactiveVars; 290793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 291793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 292793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void write( FileStorage& fs ) const 293793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 294793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( roots.empty() ) 295793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Error( CV_StsBadArg, "RTrees have not been trained" ); 296793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 297793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler writeParams(fs); 298793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 299793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "oob_error" << oobError; 300793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler if( !varImportance.empty() ) 301793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "var_importance" << varImportance; 302793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 303793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int k, ntrees = (int)roots.size(); 304793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 305793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "ntrees" << ntrees 306793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler << "trees" << "["; 307793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 308793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( k = 0; k < ntrees; k++ ) 309793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 310793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "{"; 311793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler writeTree(fs, roots[k]); 312793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "}"; 313793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 314793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 315793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler fs << "]"; 316793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 317793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 318793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void readParams( const FileNode& fn ) 319793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 320793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler DTreesImpl::readParams(fn); 321793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 322793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNode tparams_node = fn["training_params"]; 323793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler rparams.nactiveVars = (int)tparams_node["nactive_vars"]; 324793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 325793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 326793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void read( const FileNode& fn ) 327793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 328793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler clear(); 329793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 330793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler //int nclasses = (int)fn["nclasses"]; 331793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler //int nsamples = (int)fn["nsamples"]; 332793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler oobError = (double)fn["oob_error"]; 333793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int ntrees = (int)fn["ntrees"]; 334793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 335793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler readVectorOrMat(fn["var_importance"], varImportance); 336793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 337793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler readParams(fn); 338793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 339793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNode trees_node = fn["trees"]; 340793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNodeIterator it = trees_node.begin(); 341793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_Assert( ntrees == (int)trees_node.size() ); 342793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 343793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it ) 344793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 345793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler FileNode nfn = (*it)["nodes"]; 346793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler readTree(nfn); 347793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 348793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 349793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 350793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler RTreeParams rparams; 351793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler double oobError; 352793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<float> varImportance; 353793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler vector<int> allVars, activeVars; 354793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler RNG rng; 355793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}; 356793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 357793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 358793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerclass RTreesImpl : public RTrees 359793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{ 360793ee12c6df9cad3806238d32528c49a3ff9331dNoah Preslerpublic: 361793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(bool, CalculateVarImportance, impl.rparams.calcVarImportance) 362793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY(int, ActiveVarCount, impl.rparams.nactiveVars) 363793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, impl.rparams.termCrit) 364793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 365793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_WRAP_SAME_PROPERTY(int, MaxCategories, impl.params) 366793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_WRAP_SAME_PROPERTY(int, MaxDepth, impl.params) 367793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_WRAP_SAME_PROPERTY(int, MinSampleCount, impl.params) 368793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_WRAP_SAME_PROPERTY(int, CVFolds, impl.params) 369793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, impl.params) 370793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_WRAP_SAME_PROPERTY(bool, Use1SERule, impl.params) 371793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, impl.params) 372793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, impl.params) 373793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, impl.params) 374793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 375793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler RTreesImpl() {} 376793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler virtual ~RTreesImpl() {} 377793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 378793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler String getDefaultName() const { return "opencv_ml_rtrees"; } 379793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 380793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool train( const Ptr<TrainData>& trainData, int flags ) 381793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 382793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return impl.train(trainData, flags); 383793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 384793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 385793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler float predict( InputArray samples, OutputArray results, int flags ) const 386793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 387793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return impl.predict(samples, results, flags); 388793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 389793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 390793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void write( FileStorage& fs ) const 391793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 392793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler impl.write(fs); 393793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 394793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 395793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler void read( const FileNode& fn ) 396793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler { 397793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler impl.read(fn); 398793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler } 399793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 400793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); } 401793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler int getVarCount() const { return impl.getVarCount(); } 402793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 403793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool isTrained() const { return impl.isTrained(); } 404793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler bool isClassifier() const { return impl.isClassifier(); } 405793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 406793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const vector<int>& getRoots() const { return impl.getRoots(); } 407793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const vector<Node>& getNodes() const { return impl.getNodes(); } 408793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const vector<Split>& getSplits() const { return impl.getSplits(); } 409793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler const vector<int>& getSubsets() const { return impl.getSubsets(); } 410793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 411793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler DTreesImplForRTrees impl; 412793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}; 413793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 414793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 415793ee12c6df9cad3806238d32528c49a3ff9331dNoah PreslerPtr<RTrees> RTrees::create() 416793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler{ 417793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler return makePtr<RTreesImpl>(); 418793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler} 419793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 420793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler}} 421793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler 422793ee12c6df9cad3806238d32528c49a3ff9331dNoah Presler// End of file. 423