/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of Intel Corporation may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct, 326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages 336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// (including, but not limited to, procurement of substitute goods or services; 346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused 356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability, 366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort (including negligence or otherwise) arising in any way out of 376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even if advised of the possibility of such damage. 386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/ 406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h" 426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic const float ord_nan = FLT_MAX*0.5f; 446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic const int min_block_size = 1 << 16; 456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic const int block_size_delta = 1 << 10; 466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeTrainData::CvDTreeTrainData() 486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_idx = var_type = cat_count = cat_ofs = cat_map = 506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors = priors_mult = counts = buf = direction = split_buf = 0; 516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tree_storage = temp_storage = 
0; 526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag, 586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _responses, const CvMat* _var_idx, 596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _sample_idx, const CvMat* _var_type, 606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _missing_mask, const CvDTreeParams& _params, 616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool _shared, bool _add_labels ) 626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_idx = var_type = cat_count = cat_ofs = cat_map = 646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors = priors_mult = counts = buf = direction = split_buf = 0; 656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tree_storage = temp_storage = 0; 666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx, 686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _var_type, _missing_mask, _params, _shared, _add_labels ); 696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeTrainData::~CvDTreeTrainData() 736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn 786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvDTreeTrainData::set_params( const CvDTreeParams& _params ) 796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool ok = false; 816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTreeTrainData::set_params" ); 836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // set parameters 876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params = _params; 886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.max_categories < 2 ) 906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" ); 916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.max_categories = MIN( params.max_categories, 15 ); 926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.max_depth < 0 ) 946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" ); 956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.max_depth = MIN( params.max_depth, 25 ); 966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.min_sample_count = MAX(params.min_sample_count,1); 986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.cv_folds < 0 ) 1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, 1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "params.cv_folds should be =0 (the tree is not pruned) " 1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn "or n>0 (tree is pruned using n-fold cross-validation)" ); 1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.cv_folds == 1 ) 1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.cv_folds = 0; 1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.regression_accuracy < 0 ) 1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" ); 1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ok = true; 1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return ok; 1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b)) 1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int ) 1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int ) 1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#define CV_CMP_PAIRS(a,b) ((a).val < (b).val) 1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int ) 1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag, 1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _responses, const CvMat* 
_var_idx, const CvMat* _sample_idx, 1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params, 1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool _shared, bool _add_labels, bool _update_data ) 1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* sample_idx = 0; 1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* var_type0 = 0; 1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* tmp_map = 0; 1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int** int_ptr = 0; 1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeTrainData* data = 0; 1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTreeTrainData::set_data" ); 1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sample_all = 0, r_type = 0, cv_n; 1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int total_c_count = 0; 1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0; 1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step 1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi, i; 1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn char err[100]; 1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int *sidx = 0, *vidx = 0; 1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _update_data && data_root ) 1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = new CvDTreeTrainData( 
_train_data, _tflag, _responses, _var_idx, 1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels ); 1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compare new and old train data 1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !(data->var_count == var_count && 1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON && 1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON && 1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) ) 1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "The new training data must have the same types and the input and output variables " 1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "and the same categories for categorical variables" ); 1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &priors ); 1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &priors_mult ); 1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &buf ); 1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &direction ); 1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &split_buf ); 1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMemStorage( &temp_storage ); 1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors = data->priors; data->priors = 0; 1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors_mult = data->priors_mult; data->priors_mult = 0; 1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf = data->buf; data->buf = 0; 
1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_count = data->buf_count; buf_size = data->buf_size; 1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sample_count = data->sample_count; 1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn direction = data->direction; data->direction = 0; 1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split_buf = data->split_buf; data->split_buf = 0; 1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn temp_storage = data->temp_storage; data->temp_storage = 0; 1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nv_heap = data->nv_heap; cv_heap = data->cv_heap; 1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data_root = new_node( 0, sample_count, 0, 0 ); 1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn EXIT; 1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_all = 0; 1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rng = cvRNG(-1); 1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( set_params( _params )); 1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // check parameter types and sizes 1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all )); 1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _tflag == CV_ROW_SAMPLE ) 1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type); 
1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dv_step = 1; 1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _missing_mask ) 1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ms_step = _missing_mask->step, mv_step = 1; 1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type); 2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ds_step = 1; 2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _missing_mask ) 2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn mv_step = _missing_mask->step, ms_step = 1; 2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sample_count = sample_all; 2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_count = var_all; 2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _sample_idx ) 2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all )); 2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sidx = sample_idx->data.i; 2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sample_count = sample_idx->rows + sample_idx->cols - 1; 2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _var_idx ) 2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all )); 2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vidx = var_idx->data.i; 
2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_count = var_idx->rows + var_idx->cols - 1; 2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_responses) || 2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (CV_MAT_TYPE(_responses->type) != CV_32SC1 && 2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(_responses->type) != CV_32FC1) || 2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _responses->rows != 1 && _responses->cols != 1 || 2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _responses->rows + _responses->cols - 1 != sample_all ) 2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or " 2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "floating-point vector containing as many elements as " 2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the total number of samples in the training data matrix" ); 2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type )); 2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 )); 2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_var_count = 0; 2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ord_var_count = -1; 2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn is_classifier = r_type == CV_VAR_CATEGORICAL; 2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // step 0. 
calc the number of categorical vars 2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < var_count; vi++ ) 2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ? 2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_var_count++ : ord_var_count--; 2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ord_var_count = ~ord_var_count; 2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_n = params.cv_folds; 2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // set the two last elements of var_type array to be able 2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // to locate responses and cross-validation labels using 2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // the corresponding get_* functions. 2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_type->data.i[var_count] = cat_var_count; 2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_type->data.i[var_count+1] = cat_var_count+1; 2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // in case of single ordered predictor we need dummy cv_labels 2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // for safe split_node_data() operation 2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn have_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0 || _add_labels; 2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_size = (ord_var_count + get_work_var_count())*sample_count + 2; 2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn shared = _shared; 2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_count = shared ? 
3 : 2; 2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 )); 2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 )); 2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 )); 2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 )); 2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // now calculate the maximum size of split, 2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // create memory storage that will keep nodes and splits of the decision tree 2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // allocate root node and the buffer for the whole training data 2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_split_size = cvAlign(sizeof(CvDTreeSplit) + 2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*)); 2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size); 2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size); 2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size )); 2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage )); 2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nv_size = var_count*sizeof(int); 2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nv_size = MAX( nv_size, (int)sizeof(CvSetElem) ); 2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
temp_block_size = nv_size; 2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cv_n ) 2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sample_count < cv_n*MAX(params.min_sample_count,10) ) 2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, 2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "The many folds in cross-validation for such a small dataset" ); 2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) ); 2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn temp_block_size = MAX(temp_block_size, cv_size); 2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size ); 2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size )); 2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage )); 2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cv_size ) 2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage )); 2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( data_root = new_node( 0, sample_count, 0, 0 )); 3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) )); 3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_c_count = 1; 3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // transform the training data to convenient representation 3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi <= var_count; vi++ ) 3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci; 3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const uchar* mask = 0; 3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int m_step = 0, step; 3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* idata = 0; 3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* fdata = 0; 3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int num_valid = 0; 3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( vi < var_count ) // analyze i-th input variable 3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi0 = vidx ? 
vidx[vi] : vi; 3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ci = get_var_type(vi); 3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn step = ds_step; m_step = ms_step; 3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 ) 3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idata = _train_data->data.i + vi0*dv_step; 3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn fdata = _train_data->data.fl + vi0*dv_step; 3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _missing_mask ) 3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn mask = _missing_mask->data.ptr + vi0*mv_step; 3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else // analyze _responses 3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ci = cat_var_count; 3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn step = CV_IS_MAT_CONT(_responses->type) ? 
3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1 : _responses->step / CV_ELEM_SIZE(_responses->type); 3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( CV_MAT_TYPE(_responses->type) == CV_32SC1 ) 3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idata = _responses->data.i; 3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn fdata = _responses->data.fl; 3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( vi < var_count && ci >= 0 || 3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vi == var_count && is_classifier ) // process categorical variable or response 3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int c_count, prev_label; 3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* c_map, *dst = get_cat_var_data( data_root, vi ); 3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // copy data 3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < sample_count; i++ ) 3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int val = INT_MAX, si = sidx ? 
sidx[i] : i; 3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !mask || !mask[si*m_step] ) 3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( idata ) 3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = idata[si*step]; 3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float t = fdata[si*step]; 3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = cvRound(t); 3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( val != t ) 3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sprintf( err, "%d-th value of %d-th (categorical) " 3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "variable is not an integer", i, vi ); 3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, err ); 3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( val == INT_MAX ) 3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sprintf( err, "%d-th value of %d-th (categorical) " 3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "variable is too large", i, vi ); 3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, err ); 3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn num_valid++; 3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[i] = val; 3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int_ptr[i] = dst + i; 3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sort all the values, including the missing measurements 3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // that should all move to the end 3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvSortIntPtr( int_ptr, sample_count, 0 ); 3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr ); 3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c_count = num_valid > 0; 3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // count the categories 3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 1; i < num_valid; i++ ) 3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c_count += *int_ptr[i] != *int_ptr[i-1]; 3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( vi > 0 ) 3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_c_count = MAX( max_c_count, c_count ); 3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_count->data.i[ci] = c_count; 3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_ofs->data.i[ci] = total_c_count; 3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // resize cat_map, if need 3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cat_map->cols < total_c_count + c_count ) 3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tmp_map = cat_map; 3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cat_map = cvCreateMat( 1, 3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 )); 3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < total_c_count; i++ ) 
3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_map->data.i[i] = tmp_map->data.i[i]; 4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &tmp_map ); 4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c_map = cat_map->data.i + total_c_count; 4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn total_c_count += c_count; 4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compact the class indices and build the map 4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_label = ~*int_ptr[0]; 4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c_count = -1; 4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < num_valid; i++ ) 4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int cur_label = *int_ptr[i]; 4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cur_label != prev_label ) 4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c_map[++c_count] = prev_label = cur_label; 4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *int_ptr[i] = c_count; 4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // replace labels for missing values with -1 4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < sample_count; i++ ) 4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *int_ptr[i] = -1; 4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( ci < 0 ) // process ordered variable 4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvPair32s32f* dst = get_ord_var_data( 
data_root, vi ); 4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < sample_count; i++ ) 4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float val = ord_nan; 4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int si = sidx ? sidx[i] : i; 4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !mask || !mask[si*m_step] ) 4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( idata ) 4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = (float)idata[si*step]; 4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = fdata[si*step]; 4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fabs(val) >= ord_nan ) 4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sprintf( err, "%d-th value of %d-th (ordered) " 4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "variable (=%g) is too large", i, vi, val ); 4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, err ); 4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn num_valid++; 4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[i].i = i; 4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[i].val = val; 4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvSortPairs( dst, sample_count, 0 ); 4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else // special case: process ordered response, 
4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // it will be stored similarly to categorical vars (i.e. no pairs) 4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* dst = get_ord_responses( data_root ); 4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < sample_count; i++ ) 4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float val = ord_nan; 4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int si = sidx ? sidx[i] : i; 4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( idata ) 4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = (float)idata[si*step]; 4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = fdata[si*step]; 4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fabs(val) >= ord_nan ) 4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sprintf( err, "%d-th value of %d-th (ordered) " 4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "variable (=%g) is out of range", i, vi, val ); 4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, err ); 4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[i] = val; 4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_count->data.i[cat_var_count] = 0; 4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_ofs->data.i[cat_var_count] = total_c_count; 4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn num_valid = sample_count; 4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( vi < var_count ) 4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data_root->set_num_valid(vi, num_valid); 4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cv_n ) 4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* dst = get_labels(data_root); 4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvRNG* r = &rng; 4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = vi = 0; i < sample_count; i++ ) 4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[i] = vi++; 4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vi &= vi < cv_n ? -1 : 0; 4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < sample_count; i++ ) 4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int a = cvRandInt(r) % sample_count; 4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int b = cvRandInt(r) % sample_count; 4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( dst[a], dst[b], vi ); 4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_map->cols = MAX( total_c_count, 1 ); 5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_split_size = cvAlign(sizeof(CvDTreeSplit) + 5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
(MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*)); 5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage )); 5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn have_priors = is_classifier && params.priors; 5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( is_classifier ) 5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int m = get_num_classes(); 5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0; 5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( priors = cvCreateMat( 1, m, CV_64F )); 5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < m; i++ ) 5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = have_priors ? params.priors[i] : 1.; 5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( val <= 0 ) 5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" ); 5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors->data.db[i] = val; 5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += val; 5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // normalize weights 5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( have_priors ) 5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvScale( priors, priors, 1./sum ); 5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( priors_mult = cvCloneMat( priors )); 5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 )); 5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn } 5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 )); 5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 )); 5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data ) 5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn delete data; 5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &int_ptr ); 5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &sample_idx ); 5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &var_type0 ); 5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &tmp_map ); 5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx ) 5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* root = 0; 5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* isubsample_idx = 0; 5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* subsample_co = 0; 5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTreeTrainData::subsample_data" ); 5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !data_root ) 
5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsError, "No training data has been set" ); 5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _subsample_idx ) 5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count )); 5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !isubsample_idx ) 5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // make a copy of the root node 5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode temp; 5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i; 5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root = new_node( 0, 1, 0, 0 ); 5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn temp = *root; 5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *root = *data_root; 5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root->num_valid = temp.num_valid; 5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( root->num_valid ) 5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < var_count; i++ ) 5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root->num_valid[i] = data_root->num_valid[i]; 5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root->cv_Tn = temp.cv_Tn; 5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root->cv_node_risk = temp.cv_node_risk; 5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root->cv_node_error = temp.cv_node_error; 5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* 
sidx = isubsample_idx->data.i; 5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // co - array of count/offset pairs (to handle duplicated values in _subsample_idx) 5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* co, cur_ofs = 0; 5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi, i, total = data_root->sample_count; 5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int count = isubsample_idx->rows + isubsample_idx->cols - 1; 5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int work_var_count = get_work_var_count(); 5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root = new_node( 0, count, 1, 0 ); 5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 )); 5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( subsample_co ); 5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn co = subsample_co->data.i; 5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn co[sidx[i]*2]++; 5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < total; i++ ) 5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( co[i*2] ) 5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn co[i*2+1] = cur_ofs; 6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cur_ofs += co[i*2]; 6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn co[i*2+1] = -1; 6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < work_var_count; vi++ ) 
6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = get_var_type(vi); 6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci >= 0 || vi >= var_count ) 6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* src = get_cat_var_data( data_root, vi ); 6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* dst = get_cat_var_data( root, vi ); 6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int num_valid = 0; 6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int val = src[sidx[i]]; 6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[i] = val; 6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn num_valid += val >= 0; 6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( vi < var_count ) 6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root->set_num_valid(vi, num_valid); 6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* src = get_ord_var_data( data_root, vi ); 6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvPair32s32f* dst = get_ord_var_data( root, vi ); 6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int j = 0, idx, count_i; 6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int num_valid = data_root->get_num_valid(vi); 6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn for( i = 0; i < num_valid; i++ ) 6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = src[i].i; 6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn count_i = co[idx*2]; 6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( count_i ) 6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float val = src[i].val; 6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ ) 6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[j].val = val; 6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[j].i = cur_ofs; 6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root->set_num_valid(vi, j); 6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < total; i++ ) 6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = src[i].i; 6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn count_i = co[idx*2]; 6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( count_i ) 6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float val = src[i].val; 6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ ) 6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[j].val = val; 6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[j].i = cur_ofs; 
6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &isubsample_idx ); 6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &subsample_co ); 6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return root; 6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx, 6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* values, uchar* missing, 6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* responses, bool get_class_idx ) 6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* subsample_idx = 0; 6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* subsample_co = 0; 6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTreeTrainData::get_vectors" ); 6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, vi, total = sample_count, count = total, cur_ofs = 0; 6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
int* sidx = 0; 6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* co = 0; 6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _subsample_idx ) 6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count )); 6956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sidx = subsample_idx->data.i; 6966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 )); 6976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn co = subsample_co->data.i; 6986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( subsample_co ); 6996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn count = subsample_idx->cols + subsample_idx->rows - 1; 7006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 7016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn co[sidx[i]*2]++; 7026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < total; i++ ) 7036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int count_i = co[i*2]; 7056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( count_i ) 7066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn co[i*2+1] = cur_ofs*var_count; 7086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cur_ofs += count_i; 7096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( missing ) 7146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn memset( missing, 1, count*var_count ); 7156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
7166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < var_count; vi++ ) 7176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = get_var_type(vi); 7196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci >= 0 ) // categorical 7206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* dst = values + vi; 7226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn uchar* m = missing ? missing + vi : 0; 7236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* src = get_cat_var_data(data_root, vi); 7246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++, dst += var_count ) 7266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sidx ? sidx[i] : i; 7286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int val = src[idx]; 7296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *dst = (float)val; 7306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( m ) 7316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *m = val < 0; 7336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn m += var_count; 7346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else // ordered 7386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* dst = values + vi; 7406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn uchar* m = missing ? 
missing + vi : 0; 7416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* src = get_ord_var_data(data_root, vi); 7426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int count1 = data_root->get_num_valid(vi); 7436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count1; i++ ) 7456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = src[i].i; 7476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int count_i = 1; 7486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( co ) 7496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn count_i = co[idx*2]; 7516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cur_ofs = co[idx*2+1]; 7526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 7546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cur_ofs = idx*var_count; 7556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( count_i ) 7566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float val = src[i].val; 7586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; count_i > 0; count_i--, cur_ofs += var_count ) 7596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst[cur_ofs] = val; 7616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( m ) 7626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn m[cur_ofs] = 0; 7636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
7696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // copy responses 7706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( responses ) 7716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( is_classifier ) 7736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* src = get_class_labels(data_root); 7756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 7766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sidx ? sidx[i] : i; 7786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int val = get_class_idx ? src[idx] : 7796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]]; 7806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn responses[i] = (float)val; 7816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 7846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* src = get_ord_responses(data_root); 7866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 7876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sidx ? 
sidx[i] : i; 7896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn responses[i] = src[idx]; 7906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 7956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &subsample_idx ); 7976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &subsample_co ); 7986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 7996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count, 8026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int storage_idx, int offset ) 8036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap ); 8056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->sample_count = count; 8076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->depth = parent ? 
parent->depth + 1 : 0; 8086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->parent = parent; 8096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->left = node->right = 0; 8106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->split = 0; 8116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->value = 0; 8126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->class_idx = 0; 8136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->maxlr = 0.; 8146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->buf_idx = storage_idx; 8166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->offset = offset; 8176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( nv_heap ) 8186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->num_valid = (int*)cvSetNew( nv_heap ); 8196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 8206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->num_valid = 0; 8216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.; 8226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->complexity = 0; 8236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.cv_folds > 0 && cv_heap ) 8256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int cv_n = params.cv_folds; 8276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->Tn = INT_MAX; 8286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_Tn = (int*)cvSetNew( cv_heap ); 8296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double)); 8306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_node_error = node->cv_node_risk + cv_n; 8316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
8326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 8336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->Tn = 0; 8356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_Tn = 0; 8366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_node_risk = 0; 8376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_node_error = 0; 8386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return node; 8416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val, 8456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int split_point, int inversed, float quality ) 8466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap ); 8486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->var_idx = vi; 8496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->ord.c = cmp_val; 8506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->ord.split_point = split_point; 8516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->inversed = inversed; 8526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->quality = quality; 8536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->next = 0; 8546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return split; 8566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
RennCvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality ) 8606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap ); 8626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, n = (max_c_count + 31)/32; 8636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->var_idx = vi; 8656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->inversed = 0; 8666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->quality = quality; 8676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 8686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[i] = 0; 8696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->next = 0; 8706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return split; 8726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::free_node( CvDTreeNode* node ) 8766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split = node->split; 8786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn free_node_data( node ); 8796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn while( split ) 8806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* next = split->next; 8826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetRemoveByPtr( split_heap, split ); 8836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = next; 8846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->split = 0; 
8866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetRemoveByPtr( node_heap, node ); 8876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::free_node_data( CvDTreeNode* node ) 8916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( node->num_valid ) 8936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetRemoveByPtr( nv_heap, node->num_valid ); 8956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->num_valid = 0; 8966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // do not free cv_* fields, as all the cross-validation related data is released at once. 8986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::free_train_data() 9026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &counts ); 9046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &buf ); 9056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &direction ); 9066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &split_buf ); 9076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMemStorage( &temp_storage ); 9086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_heap = nv_heap = 0; 9096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid 
CvDTreeTrainData::clear() 9136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn free_train_data(); 9156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMemStorage( &tree_storage ); 9176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &var_idx ); 9196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &var_type ); 9206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &cat_count ); 9216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &cat_ofs ); 9226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &cat_map ); 9236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &priors ); 9246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &priors_mult ); 9256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node_heap = split_heap = 0; 9276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0; 9296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn have_labels = have_priors = is_classifier = false; 9306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf_count = buf_size = 0; 9326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn shared = false; 9336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data_root = 0; 9356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rng = cvRNG(-1); 9376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
9406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTreeTrainData::get_num_classes() const 9416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return is_classifier ? cat_count->data.i[cat_var_count] : 0; 9436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTreeTrainData::get_var_type(int vi) const 9476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return var_type->data.i[vi]; 9496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTreeTrainData::get_work_var_count() const 9536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return var_count + 1 + (have_labels ? 
1 : 0); 9556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi ) 9586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int oi = ~get_var_type(vi); 9606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( 0 <= oi && oi < ord_var_count ); 9616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols + 9626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn n->offset + oi*n->sample_count*2); 9636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint* CvDTreeTrainData::get_class_labels( CvDTreeNode* n ) 9676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return get_cat_var_data( n, var_count ); 9696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n ) 9736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return (float*)get_cat_var_data( n, var_count ); 9756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint* CvDTreeTrainData::get_labels( CvDTreeNode* n ) 9796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return have_labels ? 
get_cat_var_data( n, var_count + 1 ) : 0; 9816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi ) 9856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = get_var_type(vi); 9876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( 0 <= ci && ci <= cat_var_count + 1 ); 9886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return buf->data.i + n->buf_idx*buf->cols + n->offset + 9896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (ord_var_count*2 + ci)*n->sample_count; 9906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n ) 9946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = n->buf_idx + 1; 9966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( idx >= buf_count ) 9976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = shared ? 
1 : 0; 9986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return idx; 9996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 10006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::write_params( CvFileStorage* fs ) 10036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 10046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTreeTrainData::write_params" ); 10056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 10076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi, vcount = var_count; 10096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 ); 10116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "var_all", var_all ); 10126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "var_count", var_count ); 10136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "ord_var_count", ord_var_count ); 10146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "cat_var_count", cat_var_count ); 10156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "training_params", CV_NODE_MAP ); 10176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 
1 : 0 ); 10186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( is_classifier ) 10206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "max_categories", params.max_categories ); 10226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 10246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "regression_accuracy", params.regression_accuracy ); 10266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "max_depth", params.max_depth ); 10296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "min_sample_count", params.min_sample_count ); 10306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "cross_validation_folds", params.cv_folds ); 10316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.cv_folds > 1 ) 10336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 ); 10356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 
1 : 0 ); 10366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( priors ) 10396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWrite( fs, "priors", priors ); 10406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 10426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( var_idx ) 10446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWrite( fs, "var_idx", var_idx ); 10456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW ); 10476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < vcount; vi++ ) 10496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 ); 10506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 10526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cat_count && (cat_var_count > 0 || is_classifier) ) 10546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ASSERT( cat_count != 0 ); 10566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWrite( fs, "cat_count", cat_count ); 10576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWrite( fs, "cat_map", cat_map ); 10586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 10616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 10626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
10636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node ) 10656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 10666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTreeTrainData::read_params" ); 10676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 10696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode *tparams_node, *vartype_node; 10716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 10726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi, max_split_size, tree_block_size; 10736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0); 10756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_all = cvReadIntByName( fs, node, "var_all" ); 10766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_count = cvReadIntByName( fs, node, "var_count", var_all ); 10776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_var_count = cvReadIntByName( fs, node, "cat_var_count" ); 10786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ord_var_count = cvReadIntByName( fs, node, "ord_var_count" ); 10796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tparams_node = cvGetFileNodeByName( fs, node, "training_params" ); 10816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( tparams_node ) // training parameters are not necessary 10836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0; 
10856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( is_classifier ) 10876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" ); 10896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 10916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.regression_accuracy = 10936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" ); 10946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" ); 10976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" ); 10986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" ); 10996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.cv_folds > 1 ) 11016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0; 11036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.truncate_pruned_tree = 11046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0; 11056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" ); 
11086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( priors ) 11096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(priors) ) 11116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "priors must stored as a matrix" ); 11126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors_mult = cvCloneMat( priors ); 11136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" )); 11176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( var_idx ) 11186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(var_idx) || 11206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_idx->cols != 1 && var_idx->rows != 1 || 11216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_idx->cols + var_idx->rows - 1 != var_count || 11226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(var_idx->type) != CV_32SC1 ) 11236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, 11246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" ); 11256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < var_count; vi++ ) 11276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all ) 11286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" ); 11296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
11316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ////// read var type 11326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 )); 11336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_var_count = 0; 11356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ord_var_count = -1; 11366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vartype_node = cvGetFileNodeByName( fs, node, "var_type" ); 11376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 ) 11396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--; 11406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 11416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ || 11436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vartype_node->data.seq->total != var_count ) 11446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" ); 11456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartReadSeq( vartype_node->data.seq, &reader ); 11476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < var_count; vi++ ) 11496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* n = (CvFileNode*)reader.ptr; 11516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) ) 11526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( 
CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" ); 11536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--; 11546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader ); 11556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_type->data.i[var_count] = cat_var_count; 11586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ord_var_count = ~ord_var_count; 11606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cat_var_count != cat_var_count || ord_var_count != ord_var_count ) 11616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" ); 11626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ////// 11636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cat_var_count > 0 || is_classifier ) 11656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ccount, total_c_count = 0; 11676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" )); 11686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" )); 11696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) || 11716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_count->cols != 1 && cat_count->rows != 1 || 11726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(cat_count->type) != CV_32SC1 || 11736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_count->cols + 
cat_count->rows - 1 != cat_var_count + is_classifier || 11746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_map->cols != 1 && cat_map->rows != 1 || 11756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(cat_map->type) != CV_32SC1 ) 11766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, 11776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" ); 11786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ccount = cat_var_count + is_classifier; 11806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 )); 11826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_ofs->data.i[0] = 0; 11836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_c_count = 1; 11846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < ccount; vi++ ) 11866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int val = cat_count->data.i[vi]; 11886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( val <= 0 ) 11896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" ); 11906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_c_count = MAX( max_c_count, val ); 11916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cat_ofs->data.i[vi+1] = total_c_count += val; 11926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cat_map->cols + cat_map->rows - 1 != total_c_count ) 11956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadSize, 
11966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "cat_map vector length is not equal to the total number of categories in all categorical vars" ); 11976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_split_size = cvAlign(sizeof(CvDTreeSplit) + 12006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*)); 12016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size); 12036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size); 12046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size )); 12056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]), 12066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sizeof(CvDTreeNode), tree_storage )); 12076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]), 12086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_split_size, tree_storage )); 12096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 12116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/////////////////////// Decision Tree ///////////////////////// 12156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTree::CvDTree() 12176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = 0; 
12196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_importance = 0; 12206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default_model_name = "my_tree"; 12216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 12236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::clear() 12276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &var_importance ); 12296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data ) 12306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !data->shared ) 12326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn delete data; 12336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 12346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn free_tree(); 12356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = 0; 12366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root = 0; 12386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn pruned_tree_idx = -1; 12396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTree::~CvDTree() 12436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 12456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvDTreeNode* CvDTree::get_root() const 
12496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return root; 12516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTree::get_pruned_tree_idx() const 12556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return pruned_tree_idx; 12576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeTrainData* CvDTree::get_data() 12616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return data; 12636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvDTree::train( const CvMat* _train_data, int _tflag, 12676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _responses, const CvMat* _var_idx, 12686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _sample_idx, const CvMat* _var_type, 12696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _missing_mask, CvDTreeParams _params ) 12706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool result = false; 12726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTree::train" ); 12746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 12766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
12776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 12786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = new CvDTreeTrainData( _train_data, _tflag, _responses, 12796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _var_idx, _sample_idx, _var_type, 12806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _missing_mask, _params, false ); 12816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( result = do_train(0)); 12826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 12846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return result; 12866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx ) 12906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool result = false; 12926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTree::train" ); 12946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 12966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 12986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = _data; 12996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->shared = true; 13006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( result = do_train(_subsample_idx)); 13016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 13036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn return result; 13056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 13066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvDTree::do_train( const CvMat* _subsample_idx ) 13096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 13106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool result = false; 13116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTree::do_train" ); 13136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 13156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root = data->subsample_data( _subsample_idx ); 13176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( try_split_node(root)); 13196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->params.cv_folds > 0 ) 13216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( prune_cv()); 13226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !data->shared ) 13246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->free_train_data(); 13256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn result = true; 13276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 13296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return result; 13316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 13326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
13346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::try_split_node( CvDTreeNode* node ) 13356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 13366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* best_split = 0; 13376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, n = node->sample_count, vi; 13386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool can_split = true; 13396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double quality_scale; 13406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn calc_node_value( node ); 13426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( node->sample_count <= data->params.min_sample_count || 13446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->depth >= data->params.max_depth ) 13456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn can_split = false; 13466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( can_split && data->is_classifier ) 13486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // check if we have a "pure" node, 13506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // we assume that cls_count is filled by calc_node_value() 13516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* cls_count = data->counts->data.i; 13526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int nz = 0, m = data->get_num_classes(); 13536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < m; i++ ) 13546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nz += cls_count[i] != 0; 13556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( nz == 1 ) // there is only one class 13566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn can_split = false; 13576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
13586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( can_split ) 13596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sqrt(node->node_risk)/n < data->params.regression_accuracy ) 13616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn can_split = false; 13626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( can_split ) 13656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_split = find_best_split(node); 13676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // TODO: check the split quality ... 13686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->split = best_split; 13696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !can_split || !best_split ) 13726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->free_node_data(node); 13746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 13756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn quality_scale = calc_node_dir( node ); 13786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->params.use_surrogates ) 13806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // find all the surrogate splits 13826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // and sort them by their similarity to the primary one 13836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < data->var_count; vi++ ) 
13846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split; 13866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = data->get_var_type(vi); 13876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( vi == best_split->var_idx ) 13896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn continue; 13906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci >= 0 ) 13926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = find_surrogate_split_cat( node, vi ); 13936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 13946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = find_surrogate_split_ord( node, vi ); 13956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split ) 13976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // insert the split 13996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* prev_split = node->split; 14006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->quality = (float)(split->quality*quality_scale); 14016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn while( prev_split->next && 14036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_split->next->quality > split->quality ) 14046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_split = prev_split->next; 14056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->next = prev_split->next; 14066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prev_split->next = split; 14076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
14106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split_node_data( node ); 14126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn try_split_node( node->left ); 14136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn try_split_node( node->right ); 14146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 14156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// calculate direction (left(-1),right(1),missing(0)) 14186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// for each sample using the best split 14196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the function returns scale coefficients for surrogate split quality factors. 14206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the scale is applied to normalize surrogate split quality relatively to the 14216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// best (primary) split quality. That is, if a surrogate split is absolutely 14226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// identical to the primary split, its quality will be set to the maximum value = 14236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// quality of the primary split; otherwise, it will be lower. 14246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// besides, the function compute node->maxlr, 14256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// minimum possible quality (w/o considering the above mentioned scale) 14266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// for a surrogate split. Surrogate splits with quality less than node->maxlr 14276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are not discarded. 
14286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renndouble CvDTree::calc_node_dir( CvDTreeNode* node ) 14296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 14306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn char* dir = (char*)data->direction->data.ptr; 14316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, n = node->sample_count, vi = node->split->var_idx; 14326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double L, R; 14336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( !node->split->inversed ); 14356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->get_var_type(vi) >= 0 ) // split on categorical var 14376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* labels = data->get_cat_var_data(node,vi); 14396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* subset = node->split->subset; 14406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !data->have_priors ) 14426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sum = 0, sum_abs = 0; 14446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 14466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = labels[i]; 14486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = idx >= 0 ? 
CV_DTREE_CAT_DIR(idx,subset) : 0; 14496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += d; sum_abs += d & 1; 14506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[i] = (char)d; 14516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R = (sum_abs + sum) >> 1; 14546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L = (sum_abs - sum) >> 1; 14556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 14576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 14596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* priors = data->priors_mult->data.db; 14606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0, sum_abs = 0; 14616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 14636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = labels[i]; 14656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = priors[responses[i]]; 14666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = idx >= 0 ? 
CV_DTREE_CAT_DIR(idx,subset) : 0; 14676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += d*w; sum_abs += (d & 1)*w; 14686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[i] = (char)d; 14696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R = (sum_abs + sum) * 0.5; 14726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L = (sum_abs - sum) * 0.5; 14736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else // split on ordered var 14766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* sorted = data->get_ord_var_data(node,vi); 14786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int split_point = node->split->ord.split_point; 14796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 14806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( 0 <= split_point && split_point < n1-1 ); 14826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !data->have_priors ) 14846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i <= split_point; i++ ) 14866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[sorted[i].i] = (char)-1; 14876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < n1; i++ ) 14886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[sorted[i].i] = (char)1; 14896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < n; i++ ) 14906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[sorted[i].i] = (char)0; 14916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
14926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L = split_point-1; 14936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R = n1 - split_point + 1; 14946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 14966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 14986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* priors = data->priors_mult->data.db; 14996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L = R = 0; 15006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i <= split_point; i++ ) 15026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 15046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = priors[responses[idx]]; 15056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[idx] = (char)-1; 15066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L += w; 15076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < n1; i++ ) 15106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 15126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = priors[responses[idx]]; 15136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[idx] = (char)1; 15146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R += w; 15156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < n; i++ ) 15186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[sorted[i].i] = (char)0; 
15196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->maxlr = MAX( L, R ); 15236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return node->split->quality/(L + R); 15246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 15256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node ) 15286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 15296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi; 15306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit *best_split = 0, *split = 0, *t; 15316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < data->var_count; vi++ ) 15336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = data->get_var_type(vi); 15356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( node->get_num_valid(vi) <= 1 ) 15366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn continue; 15376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->is_classifier ) 15396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci >= 0 ) 15416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = find_split_cat_class( node, vi ); 15426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 15436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = find_split_ord_class( node, vi ); 15446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 
15466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci >= 0 ) 15486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = find_split_cat_reg( node, vi ); 15496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 15506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = find_split_ord_reg( node, vi ); 15516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split ) 15546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !best_split || best_split->quality < split->quality ) 15566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( best_split, split, t ); 15576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split ) 15586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetRemoveByPtr( data->split_heap, split ); 15596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return best_split; 15636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 15646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi ) 15676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 15686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float epsilon = FLT_EPSILON*2; 15696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* sorted = data->get_ord_var_data(node, vi); 15706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 15716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 
15726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 15736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int m = data->get_num_classes(); 15746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* rc0 = data->counts->data.i; 15756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* lc = (int*)cvStackAlloc(m*sizeof(lc[0])); 15766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* rc = (int*)cvStackAlloc(m*sizeof(rc[0])); 15776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, best_i = -1; 15786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double lsum2 = 0, rsum2 = 0, best_val = 0; 15796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* priors = data->have_priors ? data->priors_mult->data.db : 0; 15806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // init arrays of class instance counters on both sides of the split 15826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < m; i++ ) 15836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[i] = 0; 15856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rc[i] = rc0[i]; 15866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compensate for missing values 15896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = n1; i < n; i++ ) 15906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rc[responses[sorted[i].i]]--; 15916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !priors ) 15936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int L = 0, R = n1; 15956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; 
i < m; i++ ) 15976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum2 += (double)rc[i]*rc[i]; 15986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1 - 1; i++ ) 16006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = responses[sorted[i].i]; 16026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int lv, rv; 16036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L++; R--; 16046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lv = lc[idx]; rv = rc[idx]; 16056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lsum2 += lv*2 + 1; 16066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum2 -= rv*2 - 1; 16076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[idx] = lv + 1; rc[idx] = rv - 1; 16086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sorted[i].val + epsilon < sorted[i+1].val ) 16106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = (lsum2*R + rsum2*L)/((double)L*R); 16126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 16136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 16156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; 16166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 16216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double L = 0, R = 0; 16236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < m; i++ ) 
16246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double wv = rc[i]*priors[i]; 16266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R += wv; 16276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum2 += wv*wv; 16286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1 - 1; i++ ) 16316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = responses[sorted[i].i]; 16336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int lv, rv; 16346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double p = priors[idx], p2 = p*p; 16356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L += p; R -= p; 16366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lv = lc[idx]; rv = rc[idx]; 16376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lsum2 += p2*(lv*2 + 1); 16386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum2 -= p2*(rv*2 - 1); 16396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[idx] = lv + 1; rc[idx] = rv - 1; 16406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sorted[i].val + epsilon < sorted[i+1].val ) 16426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = (lsum2*R + rsum2*L)/((double)L*R); 16446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 16456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 16476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; 16486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
16516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return best_i >= 0 ? data->new_split_ord( vi, 16546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i, 16556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 0, (float)best_val ) : 0; 16566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 16576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::cluster_categories( const int* vectors, int n, int m, 16606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* csums, int k, int* labels ) 16616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 16626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm 16636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int iters = 0, max_iters = 100; 16646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, idx; 16656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) ); 16666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double *v_weights = buf, *c_weights = buf + k; 16676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool modified = true; 16686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvRNG* r = &data->rng; 16696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // assign labels randomly 16716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = idx = 0; i < n; i++ ) 16726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sum = 0; 16746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* v = vectors + i*m; 
16756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn labels[i] = idx++; 16766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx &= idx < k ? -1 : 0; 16776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute weight of each vector 16796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < m; j++ ) 16806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += v[j]; 16816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn v_weights[i] = sum ? 1./sum : 0.; 16826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 16856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i1 = cvRandInt(r) % n; 16876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i2 = cvRandInt(r) % n; 16886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_SWAP( labels[i1], labels[i2], j ); 16896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( iters = 0; iters <= max_iters; iters++ ) 16926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // calculate csums 16946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < k; i++ ) 16956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < m; j++ ) 16976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn csums[i*m + j] = 0; 16986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 17016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 
17026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* v = vectors + i*m; 17036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* s = csums + labels[i]*m; 17046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < m; j++ ) 17056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn s[j] += v[j]; 17066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 17076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // exit the loop here, when we have up-to-date csums 17096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( iters == max_iters || !modified ) 17106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 17116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn modified = false; 17136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // calculate weight of each cluster 17156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < k; i++ ) 17166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 17176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* s = csums + i*m; 17186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sum = 0; 17196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < m; j++ ) 17206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += s[j]; 17216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c_weights[i] = sum ? 
1./sum : 0; 17226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 17236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // now for each vector determine the closest cluster 17256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 17266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 17276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* v = vectors + i*m; 17286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double alpha = v_weights[i]; 17296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double min_dist2 = DBL_MAX; 17306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int min_idx = -1; 17316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( idx = 0; idx < k; idx++ ) 17336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 17346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* s = csums + idx*m; 17356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double dist2 = 0., beta = c_weights[idx]; 17366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < m; j++ ) 17376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 17386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = v[j]*alpha - s[j]*beta; 17396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dist2 += t*t; 17406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 17416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( min_dist2 > dist2 ) 17426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 17436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_dist2 = dist2; 17446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_idx = idx; 17456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 17466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 17476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( min_idx != labels[i] ) 
17496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn modified = true; 17506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn labels[i] = min_idx; 17516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 17526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 17536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 17546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi ) 17576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 17586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split; 17596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* labels = data->get_cat_var_data(node, vi); 17606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 17616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = data->get_var_type(vi); 17626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 17636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int m = data->get_num_classes(); 17646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int _mi = data->cat_count->data.i[ci], mi = _mi; 17656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* lc = (int*)cvStackAlloc(m*sizeof(lc[0])); 17666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* rc = (int*)cvStackAlloc(m*sizeof(rc[0])); 17676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk; 17686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) ); 17696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* cluster_labels = 0; 17706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int** int_ptr = 0; 17716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, k, idx; 
17726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double L = 0, R = 0; 17736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val = 0; 17746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0; 17756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* priors = data->priors_mult->data.db; 17766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // init array of counters: 17786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // c_{jk} - number of samples that have vi-th input variable = j and response = k. 17796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = -1; j < mi; j++ ) 17806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 17816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cjk[j*m + k] = 0; 17826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 17846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 17856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn j = labels[i]; 17866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn k = responses[i]; 17876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cjk[j*m + k]++; 17886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 17896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 17906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( m > 2 ) 17916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 17926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( mi > data->params.max_categories ) 17936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 17946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn mi = MIN(data->params.max_categories, n); 17956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cjk += _mi*m; 17966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0])); 
17976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels ); 17986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 17996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subset_i = 1; 18006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subset_n = 1 << mi; 18016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 18036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( m == 2 ); 18056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) ); 18066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < mi; j++ ) 18076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int_ptr[j] = cjk + j*2 + 1; 18086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvSortIntPtr( int_ptr, mi, 0 ); 18096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subset_i = 0; 18106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subset_n = mi; 18116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 18136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 18146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sum = 0; 18166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < mi; j++ ) 18176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += cjk[j*m + k]; 18186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rc[k] = sum; 18196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[k] = 0; 18206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 18226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < mi; j++ ) 18236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 
18246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0; 18256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 18266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += cjk[j*m + k]*priors[k]; 18276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c_weights[j] = sum; 18286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R += c_weights[j]; 18296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 18316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; subset_i < subset_n; subset_i++ ) 18326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double weight; 18346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* crow; 18356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double lsum2 = 0, rsum2 = 0; 18366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 18376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( m == 2 ) 18386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = (int)(int_ptr[subset_i] - cjk)/2; 18396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 18406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int graycode = (subset_i>>1)^subset_i; 18426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int diff = graycode ^ prevcode; 18436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 18446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // determine index of the changed bit. 18456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn Cv32suf u; 18466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = diff >= (1 << 16) ? 
16 : 0; 18476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn u.f = (float)(((diff >> 16) | diff) & 65535); 18486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx += (u.i >> 23) - 127; 18496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subtract = graycode < prevcode; 18506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prevcode = graycode; 18516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 18536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn crow = cjk + idx*m; 18546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weight = c_weights[idx]; 18556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weight < FLT_EPSILON ) 18566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn continue; 18576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 18586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !subtract ) 18596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 18616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int t = crow[k]; 18636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int lval = lc[k] + t; 18646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int rval = rc[k] - t; 18656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double p = priors[k], p2 = p*p; 18666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lsum2 += p2*lval*lval; 18676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum2 += p2*rval*rval; 18686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[k] = lval; rc[k] = rval; 18696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L += weight; 18716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R -= weight; 18726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 
18746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 18766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int t = crow[k]; 18786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int lval = lc[k] - t; 18796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int rval = rc[k] + t; 18806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double p = priors[k], p2 = p*p; 18816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lsum2 += p2*lval*lval; 18826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum2 += p2*rval*rval; 18836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[k] = lval; rc[k] = rval; 18846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L -= weight; 18866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R += weight; 18876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 18896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( L > FLT_EPSILON && R > FLT_EPSILON ) 18906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = (lsum2*R + rsum2*L)/((double)L*R); 18926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 18936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 18946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 18956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_subset = subset_i; 18966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 18996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_subset < 0 ) 19016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 0; 
19026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = data->new_split_cat( vi, (float)best_val ); 19046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( m == 2 ) 19066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 19076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i <= best_subset; i++ ) 19086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 19096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = (int)(int_ptr[i] - cjk) >> 1; 19106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[idx >> 5] |= 1 << (idx & 31); 19116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 19126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 19136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 19146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 19156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < _mi; i++ ) 19166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 19176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = cluster_labels ? 
cluster_labels[i] : i; 19186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_subset & (1 << idx) ) 19196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[i >> 5] |= 1 << (i & 31); 19206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 19216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 19226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return split; 19246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 19256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi ) 19286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 19296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float epsilon = FLT_EPSILON*2; 19306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* sorted = data->get_ord_var_data(node, vi); 19316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* responses = data->get_ord_responses(node); 19326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 19336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 19346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, best_i = -1; 19356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val = 0, lsum = 0, rsum = node->value*n; 19366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int L = 0, R = n1; 19376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compensate for missing values 19396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = n1; i < n; i++ ) 19406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum -= responses[sorted[i].i]; 19416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // find the 
optimal split 19436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1 - 1; i++ ) 19446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 19456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float t = responses[sorted[i].i]; 19466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L++; R--; 19476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lsum += t; 19486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum -= t; 19496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sorted[i].val + epsilon < sorted[i+1].val ) 19516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 19526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R); 19536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 19546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 19556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 19566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; 19576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 19586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 19596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 19606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return best_i >= 0 ? 
data->new_split_ord( vi, 19626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i, 19636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 0, (float)best_val ) : 0; 19646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 19656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi ) 19686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 19696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split; 19706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* labels = data->get_cat_var_data(node, vi); 19716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* responses = data->get_ord_responses(node); 19726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = data->get_var_type(vi); 19736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 19746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int mi = data->cat_count->data.i[ci]; 19756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1; 19766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1; 19776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double** sum_ptr = 0; 19786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, L = 0, R = 0; 19796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val = 0, lsum = 0, rsum = 0; 19806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int best_subset = -1, subset_i; 19816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = -1; i < mi; i++ ) 19836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum[i] = counts[i] = 0; 19846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
19856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // calculate sum response and weight of each category of the input var 19866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 19876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 19886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = labels[i]; 19896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s = sum[idx] + responses[i]; 19906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int nc = counts[idx] + 1; 19916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum[idx] = s; 19926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn counts[idx] = nc; 19936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 19946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 19956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // calculate average response in each category 19966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < mi; i++ ) 19976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 19986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R += counts[i]; 19996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum += sum[i]; 20006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum[i] /= MAX(counts[i],1); 20016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum_ptr[i] = sum + i; 20026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 20036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvSortDblPtr( sum_ptr, mi, 0 ); 20056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // revert back to unnormalized sums 20076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // (there should be a very little loss of accuracy) 20086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < mi; i++ ) 20096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum[i] *= counts[i]; 20106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
20116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( subset_i = 0; subset_i < mi-1; subset_i++ ) 20126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = (int)(sum_ptr[subset_i] - sum); 20146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ni = counts[idx]; 20156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ni ) 20176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s = sum[idx]; 20196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lsum += s; L += ni; 20206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum -= s; R -= ni; 20216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( L && R ) 20236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R); 20256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 20266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 20286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_subset = subset_i; 20296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 20306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 20316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 20326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 20336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_subset < 0 ) 20356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 0; 20366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = data->new_split_cat( vi, (float)best_val ); 20386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i <= 
best_subset; i++ ) 20396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = (int)(sum_ptr[i] - sum); 20416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[idx >> 5] |= 1 << (idx & 31); 20426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 20436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return split; 20456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 20466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi ) 20496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 20506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float epsilon = FLT_EPSILON*2; 20516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* sorted = data->get_ord_var_data(node, vi); 20526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* dir = (char*)data->direction->data.ptr; 20536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 20546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LL - number of samples that both the primary and the surrogate splits send to the left 20556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LR - ... primary split sends to the left and the surrogate split sends to the right 20566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // RL - ... primary split sends to the right and the surrogate split sends to the left 20576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // RR - ... 
both send to the right 20586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, best_i = -1, best_inversed = 0; 20596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val; 20606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !data->have_priors ) 20626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int LL = 0, RL = 0, LR, RR; 20646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int worst_val = cvFloor(node->maxlr), _best_val = worst_val; 20656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sum = 0, sum_abs = 0; 20666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1; i++ ) 20686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[sorted[i].i]; 20706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += d; sum_abs += d & 1; 20716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 20726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sum_abs = R + L; sum = R - L 20746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn RR = (sum_abs + sum) >> 1; 20756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn LR = (sum_abs - sum) >> 1; 20766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // initially all the samples are sent to the right by the surrogate split, 20786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LR of them are sent to the left by primary split, and RR - to the right. 20796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // now iteratively compute LL, LR, RL and RR for every possible surrogate split value. 
20806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1 - 1; i++ ) 20816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[sorted[i].i]; 20836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 20846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( d < 0 ) 20856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn LL++; LR--; 20876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val ) 20886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = LL + RR; 20906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; best_inversed = 0; 20916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 20926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 20936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( d > 0 ) 20946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn RL++; RR--; 20966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val ) 20976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 20986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = RL + LR; 20996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; best_inversed = 1; 21006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = _best_val; 21046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 21066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 21076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double LL = 0, RL = 0, LR, 
RR; 21086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double worst_val = node->maxlr; 21096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0, sum_abs = 0; 21106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* priors = data->priors_mult->data.db; 21116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 21126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = worst_val; 21136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1; i++ ) 21156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 21166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 21176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = priors[responses[idx]]; 21186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[idx]; 21196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += d*w; sum_abs += (d & 1)*w; 21206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sum_abs = R + L; sum = R - L 21236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn RR = (sum_abs + sum)*0.5; 21246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn LR = (sum_abs - sum)*0.5; 21256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // initially all the samples are sent to the right by the surrogate split, 21276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LR of them are sent to the left by primary split, and RR - to the right. 21286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // now iteratively compute LL, LR, RL and RR for every possible surrogate split value. 
21296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1 - 1; i++ ) 21306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 21316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 21326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = priors[responses[idx]]; 21336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[idx]; 21346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( d < 0 ) 21366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 21376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn LL += w; LR -= w; 21386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val ) 21396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 21406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = LL + RR; 21416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; best_inversed = 0; 21426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( d > 0 ) 21456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 21466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn RL += w; RR -= w; 21476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val ) 21486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 21496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = RL + LR; 21506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; best_inversed = 1; 21516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi, 21576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i, 21586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_inversed, (float)best_val ) : 0; 21596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 21606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi ) 21636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 21646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* labels = data->get_cat_var_data(node, vi); 21656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* dir = (char*)data->direction->data.ptr; 21666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 21676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LL - number of samples that both the primary and the surrogate splits send to the left 21686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LR - ... primary split sends to the left and the surrogate split sends to the right 21696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // RL - ... primary split sends to the right and the surrogate split sends to the left 21706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // RR - ... 
both send to the right 21716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split = data->new_split_cat( vi, 0 ); 21726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0; 21736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val = 0; 21746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1; 21756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* rc = lc + mi + 1; 21766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = -1; i < mi; i++ ) 21786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[i] = rc[i] = 0; 21796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // for each category calculate the weight of samples 21816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sent to the left (lc) and to the right (rc) by the primary split 21826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !data->have_priors ) 21836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 21846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1; 21856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* _rc = _lc + mi + 1; 21866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = -1; i < mi; i++ ) 21886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _lc[i] = _rc[i] = 0; 21896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 21916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 21926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = labels[i]; 21936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[i]; 21946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int 
sum = _lc[idx] + d; 21956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sum_abs = _rc[idx] + (d & 1); 21966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _lc[idx] = sum; _rc[idx] = sum_abs; 21976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 21986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 21996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < mi; i++ ) 22006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sum = _lc[i]; 22026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int sum_abs = _rc[i]; 22036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[i] = (sum_abs - sum) >> 1; 22046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rc[i] = (sum_abs + sum) >> 1; 22056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 22066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 22076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 22086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* priors = data->priors_mult->data.db; 22106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 22116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 22136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = labels[i]; 22156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = priors[responses[i]]; 22166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[i]; 22176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = lc[idx] + d*w; 22186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum_abs = rc[idx] + (d & 1)*w; 22196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[idx] = sum; rc[idx] = sum_abs; 22206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
22216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < mi; i++ ) 22236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = lc[i]; 22256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum_abs = rc[i]; 22266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[i] = (sum_abs - sum) * 0.5; 22276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rc[i] = (sum_abs + sum) * 0.5; 22286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 22296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 22306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 2. now form the split. 22326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // in each category send all the samples to the same direction as majority 22336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < mi; i++ ) 22346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double lval = lc[i], rval = rc[i]; 22366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( lval > rval ) 22376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[i >> 5] |= 1 << (i & 31); 22396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val += lval; 22406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn l_win++; 22416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 22426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 22436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val += rval; 22446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 22456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->quality = (float)best_val; 22476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split->quality <= node->maxlr || 
l_win == 0 || l_win == mi ) 22486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetRemoveByPtr( data->split_heap, split ), split = 0; 22496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return split; 22516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 22526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::calc_node_value( CvDTreeNode* node ) 22556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 22566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds; 22576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* cv_labels = data->get_labels(node); 22586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->is_classifier ) 22606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // in case of classification tree: 22626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * node value is the label of the class that has the largest weight in the node. 22636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * node risk is the weighted number of misclassified samples, 22646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * j-th cross-validation fold value and risk are calculated as above, 22656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // but using the samples with cv_labels(*)!=j. 22666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * j-th cross-validation fold error is calculated as the weighted number of 22676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // misclassified samples with cv_labels(*)==j. 
22686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute the number of instances of each class 22706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* cls_count = data->counts->data.i; 22716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 22726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int m = data->get_num_classes(); 22736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0])); 22746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double max_val = -1, total_weight = 0; 22756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int max_k = -1; 22766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* priors = data->priors_mult->data.db; 22776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 22796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cls_count[k] = 0; 22806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cv_n == 0 ) 22826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 22846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cls_count[responses[i]]++; 22856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 22866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 22876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cv_n; j++ ) 22896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 22906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_cls_count[j*m + k] = 0; 22916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 
22936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 22946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn j = cv_labels[i]; k = responses[i]; 22956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_cls_count[j*m + k]++; 22966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 22976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 22986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cv_n; j++ ) 22996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 23006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cls_count[k] += cv_cls_count[j*m + k]; 23016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 23036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->have_priors && node->parent == 0 ) 23046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute priors_mult from priors, take the sample ratio into account. 23066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0; 23076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 23086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n_k = cls_count[k]; 23106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors[k] = data->priors->data.db[k]*(n_k ? 
1./n_k : 0.); 23116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += priors[k]; 23126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum = 1./sum; 23146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 23156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors[k] *= sum; 23166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 23186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 23196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = cls_count[k]*priors[k]; 23216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn total_weight += val; 23226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( max_val < val ) 23236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val = val; 23256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_k = k; 23266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 23296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->class_idx = max_k; 23306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->value = data->cat_map->data.i[ 23316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->cat_ofs->data.i[data->cat_var_count] + max_k]; 23326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->node_risk = total_weight - max_val; 23336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 23346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cv_n; j++ ) 23356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum_k = 0, sum = 0, max_val_k = 0; 23376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val = -1; max_k = -1; 
23386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 23396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( k = 0; k < m; k++ ) 23406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = priors[k]; 23426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val_k = cv_cls_count[j*m + k]*w; 23436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = cls_count[k]*w - val_k; 23446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum_k += val_k; 23456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += val; 23466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( max_val < val ) 23476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val = val; 23496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_val_k = val_k; 23506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_k = k; 23516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 23546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_Tn[j] = INT_MAX; 23556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_node_risk[j] = sum - max_val; 23566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_node_error[j] = sum_k - max_val_k; 23576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 23606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // in case of regression tree: 23626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response, 23636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // n is the number of samples in the node. 
23646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2) 23656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * j-th cross-validation fold value and risk are calculated as above, 23666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // but using the samples with cv_labels(*)!=j. 23676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * j-th cross-validation fold error is calculated 23686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // using samples with cv_labels(*)==j as the test subset: 23696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2), 23706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // where node_value_j is the node value calculated 23716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // as described in the previous bullet, and summation is done 23726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // over the samples with cv_labels(*)==j. 
23736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 23746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0, sum2 = 0; 23756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* values = data->get_ord_responses(node); 23766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double *cv_sum = 0, *cv_sum2 = 0; 23776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* cv_count = 0; 23786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 23796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cv_n == 0 ) 23806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 23826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = values[i]; 23846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += t; 23856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum2 += t*t; 23866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 23896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) ); 23916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) ); 23926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) ); 23936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 23946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cv_n; j++ ) 23956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 23966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_sum[j] = cv_sum2[j] = 0.; 23976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_count[j] = 0; 23986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 23996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
24006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 24016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn j = cv_labels[i]; 24036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = values[i]; 24046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s = cv_sum[j] + t; 24056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s2 = cv_sum2[j] + t*t; 24066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int nc = cv_count[j] + 1; 24076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_sum[j] = s; 24086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_sum2[j] = s2; 24096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_count[j] = nc; 24106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cv_n; j++ ) 24136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += cv_sum[j]; 24156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum2 += cv_sum2[j]; 24166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->node_risk = sum2 - (sum/n)*sum; 24206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->value = sum/n; 24216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cv_n; j++ ) 24236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s = cv_sum[j], si = sum - s; 24256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s2 = cv_sum2[j], s2i = sum2 - s2; 24266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int c = cv_count[j], ci = n - c; 
24276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double r = si/MAX(ci,1); 24286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_node_risk[j] = s2i - r*r*ci; 24296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_node_error[j] = s2 - 2*r*s + c*r*r; 24306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_Tn[j] = INT_MAX; 24316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 24346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::complete_node_dir( CvDTreeNode* node ) 24376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 24386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1; 24396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int nz = n - node->get_num_valid(node->split->var_idx); 24406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn char* dir = (char*)data->direction->data.ptr; 24416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // try to complete direction using surrogate splits 24436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( nz && data->params.use_surrogates ) 24446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split = node->split->next; 24466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; split != 0 && nz; split = split->next ) 24476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int inversed_mask = split->inversed ? 
-1 : 0; 24496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vi = split->var_idx; 24506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->get_var_type(vi) >= 0 ) // split on categorical var 24526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* labels = data->get_cat_var_data(node, vi); 24546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* subset = split->subset; 24556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 24576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx; 24596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !dir[i] && (idx = labels[i]) >= 0 ) 24606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = CV_DTREE_CAT_DIR(idx,subset); 24626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[i] = (char)((d ^ inversed_mask) - inversed_mask); 24636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( --nz ) 24646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 24656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else // split on ordered var 24696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* sorted = data->get_ord_var_data(node, vi); 24716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int split_point = split->ord.split_point; 24726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 24736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn assert( 0 <= split_point && split_point < n-1 ); 24756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1; i++ ) 24776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 24796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !dir[idx] ) 24806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = i <= split_point ? -1 : 1; 24826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[idx] = (char)((d ^ inversed_mask) - inversed_mask); 24836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( --nz ) 24846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 24856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 24916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // find the default direction for the rest 24926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( nz ) 24936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 24946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = nr = 0; i < n; i++ ) 24956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nr += dir[i] > 0; 24966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nl = n - nr - nz; 24976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn d0 = nl > nr ? 
-1 : nr > nl; 24986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 24996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // make sure that every sample is directed either to the left or to the right 25016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 25026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 25036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[i]; 25046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !d ) 25056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 25066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn d = d0; 25076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !d ) 25086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn d = d1, d1 = -d1; 25096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 25106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn d = d > 0; 25116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[i] = (char)d; // remap (-1,1) to (0,1) 25126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 25136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 25146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::split_node_data( CvDTreeNode* node ) 25176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 25186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi, i, n = node->sample_count, nl, nr; 25196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn char* dir = (char*)data->direction->data.ptr; 25206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode *left = 0, *right = 0; 25216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* new_idx = data->split_buf->data.i; 25226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int new_buf_idx = data->get_child_buf_idx( node ); 25236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int work_var_count = 
data->get_work_var_count(); 25246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // speedup things a little, especially for tree ensembles with a lots of small trees: 25266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // do not physically split the input data between the left and right child nodes 25276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // when we are not going to split them further, 25286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // as calc_node_value() does not requires input features anyway. 25296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool split_input_data; 25306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn complete_node_dir(node); 25326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = nl = nr = 0; i < n; i++ ) 25346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 25356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[i]; 25366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // initialize new indices for splitting ordered variables 25376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? 
ri : li 25386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nr += d; 25396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nl += d^1; 25406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 25416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->left = left = data->new_node( node, nl, new_buf_idx, node->offset ); 25436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + 25446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (data->ord_var_count + work_var_count)*nl ); 25456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split_input_data = node->depth + 1 < data->params.max_depth && 25476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (node->left->sample_count > data->params.min_sample_count || 25486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->right->sample_count > data->params.min_sample_count); 25496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // split ordered variables, keep both halves sorted. 
25516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < data->var_count; vi++ ) 25526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 25536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = data->get_var_type(vi); 25546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 25556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst; 25566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvPair32s32f tl, tr; 25576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci >= 0 || !split_input_data ) 25596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn continue; 25606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn src = data->get_ord_var_data(node, vi); 25626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst0 = ldst = data->get_ord_var_data(left, vi); 25636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rdst0 = rdst = data->get_ord_var_data(right, vi); 25646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tl = ldst0[nl]; tr = rdst0[nr]; 25656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // split sorted 25676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1; i++ ) 25686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 25696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = src[i].i; 25706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float val = src[i].val; 25716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[idx]; 25726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = new_idx[idx]; 25736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst->i = rdst->i = idx; 25746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst->val = rdst->val = val; 25756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst += d^1; 
25766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rdst += d; 25776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 25786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn left->set_num_valid(vi, (int)(ldst - ldst0)); 25806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn right->set_num_valid(vi, (int)(rdst - rdst0)); 25816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // split missing 25836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < n; i++ ) 25846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 25856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = src[i].i; 25866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[idx]; 25876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = new_idx[idx]; 25886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst->i = rdst->i = idx; 25896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst->val = rdst->val = ord_nan; 25906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst += d^1; 25916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rdst += d; 25926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 25936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst0[nl] = tl; rdst0[nr] = tr; 25956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 25966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 25976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // split categorical vars, responses and cv_labels using new_idx relocation table 25986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( vi = 0; vi < work_var_count; vi++ ) 25996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 26006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = data->get_var_type(vi); 26016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi), nr1 = 0; 
26026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int *src, *ldst0, *rdst0, *ldst, *rdst; 26036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int tl, tr; 26046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci < 0 || (vi < data->var_count && !split_input_data) ) 26066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn continue; 26076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn src = data->get_cat_var_data(node, vi); 26096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst0 = ldst = data->get_cat_var_data(left, vi); 26106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rdst0 = rdst = data->get_cat_var_data(right, vi); 26116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tl = ldst0[nl]; tr = rdst0[nr]; 26126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 26146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 26156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[i]; 26166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int val = src[i]; 26176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn *ldst = *rdst = val; 26186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst += d^1; 26196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rdst += d; 26206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nr1 += (val >= 0)&d; 26216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 26226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( vi < data->var_count ) 26246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 26256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn left->set_num_valid(vi, n1 - nr1); 26266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn right->set_num_valid(vi, nr1); 26276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
26286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ldst0[nl] = tl; rdst0[nr] = tr; 26306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 26316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // deallocate the parent node data that is not needed anymore 26336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->free_node_data(node); 26346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 26356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::prune_cv() 26386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 26396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* ab = 0; 26406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* temp = 0; 26416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat* err_jk = 0; 26426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}. 26446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 2. choose the best tree index (if need, apply 1SE rule). 26456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 3. store the best index and cut the branches. 
26466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTree::prune_cv" ); 26486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 26506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count; 26526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // currently, 1SE for regression is not implemented 26536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier; 26546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* err; 26556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double min_err = 0, min_err_se = 0; 26566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int min_idx = -1; 26576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( ab = cvCreateMat( 1, 256, CV_64F )); 26596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // build the main tree sequence, calculate alpha's 26616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;tree_count++) 26626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 26636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double min_alpha = update_tree_rnc(tree_count, -1); 26646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cut_tree(tree_count, -1, min_alpha) ) 26656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 26666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ab->cols <= tree_count ) 26686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 26696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F )); 
26706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ti = 0; ti < ab->cols; ti++ ) 26716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn temp->data.db[ti] = ab->data.db[ti]; 26726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &ab ); 26736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ab = temp; 26746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn temp = 0; 26756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 26766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ab->data.db[tree_count] = min_alpha; 26786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 26796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ab->data.db[0] = 0.; 26816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( tree_count > 0 ) 26836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 26846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ti = 1; ti < tree_count-1; ti++ ) 26856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]); 26866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ab->data.db[tree_count-1] = DBL_MAX*0.5; 26876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F )); 26896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn err = err_jk->data.db; 26906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 26916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cv_n; j++ ) 26926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 26936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int tj = 0, tk = 0; 26946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; tk < tree_count; tj++ ) 26956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 
26966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double min_alpha = update_tree_rnc(tj, j); 26976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cut_tree(tj, j, min_alpha) ) 26986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_alpha = DBL_MAX; 26996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; tk < tree_count; tk++ ) 27016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ab->data.db[tk] > min_alpha ) 27036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 27046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn err[j*tree_count + tk] = root->tree_error; 27056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ti = 0; ti < tree_count; ti++ ) 27106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum_err = 0; 27126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < cv_n; j++ ) 27136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum_err += err[j*tree_count + ti]; 27146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ti == 0 || sum_err < min_err ) 27156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_err = sum_err; 27176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_idx = ti; 27186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( use_1se ) 27196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_err_se = sqrt( sum_err*(n - sum_err) ); 27206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( sum_err < min_err + min_err_se ) 
27226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_idx = ti; 27236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn pruned_tree_idx = min_idx; 27276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn free_prune_data(data->params.truncate_pruned_tree != 0); 27286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 27306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &err_jk ); 27326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &ab ); 27336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &temp ); 27346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 27356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renndouble CvDTree::update_tree_rnc( int T, int fold ) 27386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 27396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node = root; 27406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double min_alpha = DBL_MAX; 27416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 27436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* parent; 27456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 27466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int t = fold >= 0 ? 
node->cv_Tn[fold] : node->Tn; 27486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( t <= T || !node->left ) 27496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->complexity = 1; 27516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->tree_risk = node->node_risk; 27526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->tree_error = 0.; 27536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fold >= 0 ) 27546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->tree_risk = node->cv_node_risk[fold]; 27566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->tree_error = node->cv_node_error[fold]; 27576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 27596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = node->left; 27616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( parent = node->parent; parent && parent->right == node; 27646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent, parent = parent->parent ) 27656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->complexity += node->complexity; 27676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->tree_risk += node->tree_risk; 27686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->tree_error += node->tree_error; 27696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->alpha = ((fold >= 0 ? 
parent->cv_node_risk[fold] : parent->node_risk) 27716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn - parent->tree_risk)/(parent->complexity - 1); 27726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn min_alpha = MIN( min_alpha, parent->alpha ); 27736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !parent ) 27766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 27776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->complexity = node->complexity; 27796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->tree_risk = node->tree_risk; 27806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->tree_error = node->tree_error; 27816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent->right; 27826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 27836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return min_alpha; 27856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 27866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTree::cut_tree( int T, int fold, double min_alpha ) 27896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 27906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node = root; 27916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !node->left ) 27926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 1; 27936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 27946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 27956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* parent; 27976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 
27986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 27996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn; 28006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( t <= T || !node->left ) 28016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 28026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( node->alpha <= min_alpha + FLT_EPSILON ) 28036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 28046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( fold >= 0 ) 28056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_Tn[fold] = T; 28066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 28076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->Tn = T; 28086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( node == root ) 28096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 1; 28106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 28116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 28126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = node->left; 28136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 28146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( parent = node->parent; parent && parent->right == node; 28166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent, parent = parent->parent ) 28176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ; 28186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !parent ) 28206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 28216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent->right; 28236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 28246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 0; 
28266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 28276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::free_prune_data(bool cut_tree) 28306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 28316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node = root; 28326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 28346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 28356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* parent; 28366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 28376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 28386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn ) 28396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // as we will clear the whole cross-validation heap at the end 28406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_Tn = 0; 28416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->cv_node_error = node->cv_node_risk = 0; 28426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !node->left ) 28436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 28446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = node->left; 28456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 28466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( parent = node->parent; parent && parent->right == node; 28486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent, parent = parent->parent ) 28496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 28506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( cut_tree && parent->Tn <= pruned_tree_idx ) 28516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 
28526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->free_node( parent->left ); 28536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->free_node( parent->right ); 28546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->left = parent->right = 0; 28556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 28566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 28576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !parent ) 28596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 28606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent->right; 28626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 28636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->cv_heap ) 28656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvClearSet( data->cv_heap ); 28666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 28676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::free_tree() 28706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 28716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( root && data && data->shared ) 28726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 28736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn pruned_tree_idx = INT_MIN; 28746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn free_prune_data(true); 28756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->free_node(root); 28766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root = 0; 28776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 28786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 28796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
28816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeNode* CvDTree::predict( const CvMat* _sample, 28826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _missing, bool preprocessed_input ) const 28836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 28846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* result = 0; 28856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* catbuf = 0; 28866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTree::predict" ); 28886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 28906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 28916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, step, mstep = 0; 28926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* sample; 28936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const uchar* m = 0; 28946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node = root; 28956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* vtype; 28966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* vidx; 28976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* cmap; 28986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* cofs; 28996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !node ) 29016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsError, "The tree has not been trained yet" ); 29026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 || 29046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample->cols != 1 && _sample->rows != 1 || 29056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample->cols + _sample->rows - 1 != data->var_all && 
!preprocessed_input || 29066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input ) 29076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 29086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the input sample must be 1d floating-point vector with the same " 29096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "number of elements as the total number of variables used for training" ); 29106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sample = _sample->data.fl; 29126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]); 29136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->cat_count && !preprocessed_input ) // cache for categorical variables 29156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 29166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = data->cat_count->cols; 29176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0])); 29186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 29196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn catbuf[i] = -1; 29206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 29216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _missing ) 29236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 29246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) || 29256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn !CV_ARE_SIZES_EQ(_missing, _sample) ) 29266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 29276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the missing data mask must be 8-bit vector of the same 
size as input sample" ); 29286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn m = _missing->data.ptr; 29296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]); 29306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 29316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vtype = data->var_type->data.i; 29336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0; 29346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cmap = data->cat_map ? data->cat_map->data.i : 0; 29356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cofs = data->cat_ofs ? data->cat_ofs->data.i : 0; 29366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn while( node->Tn > pruned_tree_idx && node->left ) 29386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 29396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split = node->split; 29406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int dir = 0; 29416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; !dir && split != 0; split = split->next ) 29426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 29436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi = split->var_idx; 29446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = vtype[vi]; 29456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn i = vidx ? vidx[vi] : vi; 29466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float val = sample[i*step]; 29476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( m && m[i*mstep] ) 29486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn continue; 29496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci < 0 ) // ordered 29506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir = val <= split->ord.c ? 
-1 : 1; 29516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else // categorical 29526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 29536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int c; 29546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( preprocessed_input ) 29556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c = cvRound(val); 29566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 29576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 29586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c = catbuf[ci]; 29596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( c < 0 ) 29606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 29616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int a = c = cofs[ci]; 29626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int b = cofs[ci+1]; 29636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ival = cvRound(val); 29646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ival != val ) 29656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 29666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "one of input categorical variable is not an integer" ); 29676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn while( a < b ) 29696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 29706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c = (a + b) >> 1; 29716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ival < cmap[c] ) 29726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn b = c; 29736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( ival > cmap[c] ) 29746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn a = c+1; 29756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 29766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 29776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 29786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
29796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( c < 0 || ival != cmap[c] ) 29806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn continue; 29816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn catbuf[ci] = c -= cofs[ci]; 29836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 29846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 29856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir = CV_DTREE_CAT_DIR(c, split->subset); 29866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 29876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split->inversed ) 29896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir = -dir; 29906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 29916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 29926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !dir ) 29936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 29946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double diff = node->right->sample_count - node->left->sample_count; 29956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir = diff < 0 ? -1 : 1; 29966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 29976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = dir < 0 ? 
node->left : node->right; 29986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 29996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn result = node; 30016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 30036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return result; 30056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 30066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvMat* CvDTree::get_var_importance() 30096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 30106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !var_importance ) 30116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 30126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node = root; 30136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* importance; 30146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !node ) 30156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 0; 30166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_importance = cvCreateMat( 1, data->var_count, CV_64F ); 30176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvZero( var_importance ); 30186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn importance = var_importance->data.db; 30196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 30216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 30226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* parent; 30236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ;; node = node->left ) 30246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 30256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split = node->split; 
30266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !node->left || node->Tn <= pruned_tree_idx ) 30286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 30296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; split != 0; split = split->next ) 30316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn importance[split->var_idx] += split->quality; 30326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 30336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( parent = node->parent; parent && parent->right == node; 30356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent, parent = parent->parent ) 30366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ; 30376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !parent ) 30396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 30406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent->right; 30426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 30436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvNormalize( var_importance, var_importance, 1., 0, CV_L1 ); 30456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 30466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return var_importance; 30486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 30496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) 30526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 
30536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci; 30546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW ); 30566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "var", split->var_idx ); 30576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "quality", split->quality ); 30586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ci = data->get_var_type(split->var_idx); 30606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci >= 0 ) // split on a categorical var 30616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 30626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir; 30636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 30646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0; 30656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // ad-hoc rule when to use inverse categorical split notation 30676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // to achieve more compact and clear representation 30686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1; 30696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ? 
30716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW ); 30726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 30746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 30756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int dir = CV_DTREE_CAT_DIR(i,split->subset); 30766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( dir*default_dir < 0 ) 30776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, 0, i ); 30786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 30796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 30806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 30816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 30826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c ); 30836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 30856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 30866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) 30896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 30906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split; 30916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, 0, CV_NODE_MAP ); 30936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "depth", node->depth ); 30956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "sample_count", node->sample_count ); 30966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "value", 
node->value ); 30976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 30986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->is_classifier ) 30996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "norm_class_idx", node->class_idx ); 31006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "Tn", node->Tn ); 31026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "complexity", node->complexity ); 31036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "alpha", node->alpha ); 31046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "node_risk", node->node_risk ); 31056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "tree_risk", node->tree_risk ); 31066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "tree_error", node->tree_error ); 31076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( node->left ) 31096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 31106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "splits", CV_NODE_SEQ ); 31116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( split = node->split; split != 0; split = split->next ) 31136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn write_split( fs, split ); 31146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 31166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 31176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 31196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 31206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
31226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write_tree_nodes( CvFileStorage* fs ) 31236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 31246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn //CV_FUNCNAME( "CvDTree::write_tree_nodes" ); 31256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 31276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node = root; 31296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // traverse the tree and save all the nodes in depth-first order 31316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 31326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 31336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* parent; 31346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 31356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 31366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn write_node( fs, node ); 31376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !node->left ) 31386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 31396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = node->left; 31406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 31416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( parent = node->parent; parent && parent->right == node; 31436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent, parent = parent->parent ) 31446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ; 31456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !parent ) 31476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 31486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn node = parent->right; 31506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 31516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 31536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 31546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write( CvFileStorage* fs, const char* name ) 31576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 31586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn //CV_FUNCNAME( "CvDTree::write" ); 31596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 31616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE ); 31636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn get_var_importance(); 31656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->write_params( fs ); 31666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( var_importance ) 31676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWrite( fs, "var_importance", var_importance ); 31686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn write( fs ); 31696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 31716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 31736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 31746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write( CvFileStorage* fs ) 31776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 
31786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn //CV_FUNCNAME( "CvDTree::write" ); 31796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 31816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "best_tree_idx", pruned_tree_idx ); 31836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ ); 31856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn write_tree_nodes( fs ); 31866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 31876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 31896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 31906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode ) 31936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 31946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split = 0; 31956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTree::read_split" ); 31976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 31986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 31996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int vi, ci; 32016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP ) 32036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" ); 
32046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vi = cvReadIntByName( fs, fnode, "var", -1 ); 32066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( (unsigned)vi >= (unsigned)data->var_count ) 32076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" ); 32086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ci = data->get_var_type(vi); 32106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci >= 0 ) // split on categorical var 32116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 32126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, n = data->cat_count->data.i[ci], inversed = 0, val; 32136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 32146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* inseq; 32156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = data->new_split_cat( vi, 0 ); 32166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn inseq = cvGetFileNodeByName( fs, fnode, "in" ); 32176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !inseq ) 32186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 32196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn inseq = cvGetFileNodeByName( fs, fnode, "not_in" ); 32206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn inversed = 1; 32216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 32226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !inseq || 32236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT)) 32246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, 32256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "Either 'in' or 'not_in' tags should be inside a categorical split data" ); 32266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
32276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT ) 32286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 32296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = inseq->data.i; 32306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( (unsigned)val >= (unsigned)n ) 32316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" ); 32326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[val >> 5] |= 1 << (val & 31); 32346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 32356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 32366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 32376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartReadSeq( inseq->data.seq, &reader ); 32386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < reader.seq->total; i++ ) 32406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 32416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* inode = (CvFileNode*)reader.ptr; 32426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = inode->data.i; 32436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n ) 32446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" ); 32456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[val >> 5] |= 1 << (val & 31); 32476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader ); 32486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 32496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 32506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
32516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // for categorical splits we do not use inversed splits, 32526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // instead we inverse the variable set in the split 32536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( inversed ) 32546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < (n + 31) >> 5; i++ ) 32556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[i] ^= -1; 32566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 32576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 32586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 32596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* cmp_node; 32606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = data->new_split_ord( vi, 0, 0, 0, 0 ); 32616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cmp_node = cvGetFileNodeByName( fs, fnode, "le" ); 32636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !cmp_node ) 32646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 32656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cmp_node = cvGetFileNodeByName( fs, fnode, "gt" ); 32666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->inversed = 1; 32676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 32686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->ord.c = (float)cvReadReal( cmp_node ); 32706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 32716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->quality = (float)cvReadRealByName( fs, fnode, "quality" ); 32736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 32756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return split; 
32776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 32786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent ) 32816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 32826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node = 0; 32836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTree::read_node" ); 32856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 32876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* splits; 32896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, depth; 32906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP ) 32926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" ); 32936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( node = data->new_node( parent, 0, 0, 0 )); 32956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn depth = cvReadIntByName( fs, fnode, "depth", -1 ); 32966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( depth != node->depth ) 32976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "incorrect node depth" ); 32986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 32996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->sample_count = cvReadIntByName( fs, fnode, "sample_count" ); 33006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->value = cvReadRealByName( fs, fnode, "value" ); 
33016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->is_classifier ) 33026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" ); 33036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->Tn = cvReadIntByName( fs, fnode, "Tn" ); 33056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->complexity = cvReadIntByName( fs, fnode, "complexity" ); 33066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->alpha = cvReadRealByName( fs, fnode, "alpha" ); 33076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->node_risk = cvReadRealByName( fs, fnode, "node_risk" ); 33086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" ); 33096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->tree_error = cvReadRealByName( fs, fnode, "tree_error" ); 33106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn splits = cvGetFileNodeByName( fs, fnode, "splits" ); 33126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( splits ) 33136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 33146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 33156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* last_split = 0; 33166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ ) 33186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" ); 33196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartReadSeq( splits->data.seq, &reader ); 33216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < reader.seq->total; i++ ) 33226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 
33236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split; 33246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr )); 33256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !last_split ) 33266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->split = last_split = split; 33276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 33286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn last_split = last_split->next = split; 33296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader ); 33316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 33326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 33336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 33356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return node; 33376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 33386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode ) 33416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 33426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTree::read_tree_nodes" ); 33436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 33456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 33476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode _root; 33486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* parent = &_root; 33496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i; 
33506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->left = parent->right = parent->parent = 0; 33516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartReadSeq( fnode->data.seq, &reader ); 33536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < reader.seq->total; i++ ) 33556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 33566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node; 33576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 )); 33596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !parent->left ) 33606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->left = node; 33616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 33626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent->right = node; 33636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( node->split ) 33646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent = node; 33656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 33666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 33676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn while( parent && parent->right ) 33686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn parent = parent->parent; 33696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 33706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader ); 33726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 33736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn root = _root.left; 33756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
__END__; 33776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 33786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::read( CvFileStorage* fs, CvFileNode* fnode ) 33816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 33826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeTrainData* _data = new CvDTreeTrainData(); 33836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _data->read_params( fs, fnode ); 33846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn read( fs, fnode, _data ); 33866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn get_var_importance(); 33876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 33886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// a special entry point for reading weak decision trees from the tree ensembles 33916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data ) 33926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 33936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvDTree::read" ); 33946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 33966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* tree_nodes; 33986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 33996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 34006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = _data; 34016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 34026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tree_nodes = cvGetFileNodeByName( fs, node, "nodes" ); 
34036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ ) 34046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "nodes tag is missing" ); 34056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 34066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 ); 34076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn read_tree_nodes( fs, tree_nodes ); 34086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 34096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 34106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 34116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 34126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* End of file. */ 3413