/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistributions of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistributions in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "_ml.h"

static const float ord_nan = FLT_MAX*0.5f;
static const int min_block_size = 1 << 16;
static const int block_size_delta = 1 << 10;


CvDTreeTrainData::CvDTreeTrainData()
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = buf = direction = split_buf = 0;
    tree_storage = temp_storage = 0;

    clear();
}


CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
                      const CvMat* _responses, const CvMat* _var_idx,
                      const CvMat* _sample_idx, const CvMat* _var_type,
                      const CvMat* _missing_mask, const CvDTreeParams& _params,
                      bool _shared, bool _add_labels )
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = buf = direction = split_buf = 0;
    tree_storage = temp_storage = 0;

    set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
              _var_type, _missing_mask, _params, _shared, _add_labels );
}


CvDTreeTrainData::~CvDTreeTrainData()
{
    clear();
}


bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
{
    bool ok = false;

    CV_FUNCNAME( "CvDTreeTrainData::set_params" );

    __BEGIN__;

    // set parameters
    params = _params;

    if( params.max_categories < 2 )
        CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
    params.max_categories = MIN( params.max_categories, 15 );

    if( params.max_depth < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
    params.max_depth = MIN( params.max_depth, 25 );

    params.min_sample_count = MAX(params.min_sample_count,1);

    if( params.cv_folds < 0 )
        CV_ERROR( CV_StsOutOfRange,
        "params.cv_folds should be =0 (the tree is not pruned) "
        "or n>0 (tree is pruned using n-fold cross-validation)" );

    if( params.cv_folds == 1 )
        params.cv_folds = 0;

    if( params.regression_accuracy < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );

    ok = true;

    __END__;

    return ok;
}
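
/*
    A usage sketch (not part of the library): roughly how the parameters
    validated by set_params() are supplied by a caller. The CvDTreeParams
    constructor argument order is assumed to match ml.h, and the concrete
    values are illustrative only.

    CvMat* train_data = ...; // 32FC1, one sample per row (CV_ROW_SAMPLE)
    CvMat* responses  = ...; // 32SC1/32FC1 vector, one response per sample

    CvDTreeParams params( 8,      // max_depth (clamped to 25 by set_params)
                          10,     // min_sample_count
                          0.01f,  // regression_accuracy
                          true,   // use_surrogates
                          10,     // max_categories (clamped to 15)
                          5,      // cv_folds: 0/1 = no pruning, n > 1 = n-fold CV pruning
                          true,   // use_1se_rule
                          false,  // truncate_pruned_tree
                          0 );    // priors

    CvDTree tree;
    tree.train( train_data, CV_ROW_SAMPLE, responses, 0, 0, 0, 0, params );
*/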
#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )

#define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )

void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
    const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
    const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
    bool _shared, bool _add_labels, bool _update_data )
{
    CvMat* sample_idx = 0;
    CvMat* var_type0 = 0;
    CvMat* tmp_map = 0;
    int** int_ptr = 0;
    CvDTreeTrainData* data = 0;

    CV_FUNCNAME( "CvDTreeTrainData::set_data" );

    __BEGIN__;

    int sample_all = 0, r_type = 0, cv_n;
    int total_c_count = 0;
    int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
    int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
    int vi, i;
    char err[100];
    const int *sidx = 0, *vidx = 0;

    if( _update_data && data_root )
    {
        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );

        // compare new and old train data
        if( !(data->var_count == var_count &&
            cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
            CV_ERROR( CV_StsBadArg,
            "The new training data must have the same types of input and output "
            "variables and the same categories for categorical variables" );

        cvReleaseMat( &priors );
        cvReleaseMat( &priors_mult );
        cvReleaseMat( &buf );
        cvReleaseMat( &direction );
        cvReleaseMat( &split_buf );
        cvReleaseMemStorage( &temp_storage );

        priors = data->priors; data->priors = 0;
        priors_mult = data->priors_mult; data->priors_mult = 0;
        buf = data->buf; data->buf = 0;
        buf_count = data->buf_count; buf_size = data->buf_size;
        sample_count = data->sample_count;

        direction = data->direction; data->direction = 0;
        split_buf = data->split_buf; data->split_buf = 0;
        temp_storage = data->temp_storage; data->temp_storage = 0;
        nv_heap = data->nv_heap; cv_heap = data->cv_heap;

        data_root = new_node( 0, sample_count, 0, 0 );
        EXIT;
    }

    clear();

    var_all = 0;
    rng = cvRNG(-1);

    CV_CALL( set_params( _params ));

    // check parameter types and sizes
    CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
    if( _tflag == CV_ROW_SAMPLE )
    {
        ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        dv_step = 1;
        if( _missing_mask )
            ms_step = _missing_mask->step, mv_step = 1;
    }
    else
    {
        dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        ds_step = 1;
        if( _missing_mask )
            mv_step = _missing_mask->step, ms_step = 1;
    }

    sample_count = sample_all;
    var_count = var_all;

    if( _sample_idx )
    {
        CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
        sidx = sample_idx->data.i;
        sample_count = sample_idx->rows + sample_idx->cols - 1;
    }

    if( _var_idx )
    {
        CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
        vidx = var_idx->data.i;
        var_count = var_idx->rows + var_idx->cols - 1;
    }

    if( !CV_IS_MAT(_responses) ||
        (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
         CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
        (_responses->rows != 1 && _responses->cols != 1) ||
        _responses->rows + _responses->cols - 1 != sample_all )
        CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
                  "floating-point vector containing as many elements as "
                  "the total number of samples in the training data matrix" );

    CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
    CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;

    is_classifier = r_type == CV_VAR_CATEGORICAL;

    // step 0. calc the number of categorical vars
    for( vi = 0; vi < var_count; vi++ )
    {
        var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
            cat_var_count++ : ord_var_count--;
    }

    ord_var_count = ~ord_var_count;
    cv_n = params.cv_folds;
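
    /*
        Encoding of var_type (a note derived from the loop above and the
        get_* accessors below): the vi-th element holds a running index,
        not a flag. Categorical variables are numbered 0, 1, 2, ... while
        ordered variables get ~0, ~1, ~2, ... (i.e. -1, -2, -3, ...). For
        instance, for the variable sequence (ordered, categorical, ordered,
        categorical) the stored codes are:

            var_type->data.i[] = { -1, 0, -2, 1 }

        so ci = get_var_type(vi) >= 0 directly indexes cat_count/cat_ofs for
        a categorical variable, while oi = ~get_var_type(vi) indexes the
        ordered part of the buffer. The ord_var_count = ~ord_var_count line
        above turns the decremented counter back into the number of ordered
        variables.
    */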
    // set the two last elements of var_type array to be able
    // to locate responses and cross-validation labels using
    // the corresponding get_* functions.
    var_type->data.i[var_count] = cat_var_count;
    var_type->data.i[var_count+1] = cat_var_count+1;

    // in case of a single ordered predictor we need dummy cv_labels
    // for safe split_node_data() operation
    have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;

    buf_size = (ord_var_count + get_work_var_count())*sample_count + 2;
    shared = _shared;
    buf_count = shared ? 3 : 2;
    CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
    CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
    CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
    CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));

    // now calculate the maximum size of split,
    // create memory storage that will keep nodes and splits of the decision tree
    // allocate root node and the buffer for the whole training data
    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));

    nv_size = var_count*sizeof(int);
    nv_size = MAX( nv_size, (int)sizeof(CvSetElem) );

    temp_block_size = nv_size;

    if( cv_n )
    {
        if( sample_count < cv_n*MAX(params.min_sample_count,10) )
            CV_ERROR( CV_StsOutOfRange,
                "Too many folds in cross-validation for such a small dataset" );

        cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
        temp_block_size = MAX(temp_block_size, cv_size);
    }

    temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
    CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
    CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
    if( cv_size )
        CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));

    CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
    CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));

    max_c_count = 1;
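
    /*
        Layout of one row of 'buf' (inferred from get_ord_var_data() and
        get_cat_var_data() below; a node reads its data starting at
        buf->data.i + buf_idx*buf->cols + offset):

            [ ordered var 0: sample_count (index,value) pairs, i.e.
              2*sample_count ints, sorted by value, missing values at the end ]
              ...
            [ ordered var ord_var_count-1 ]
            [ categorical var 0: sample_count ints, normalized category
              indices, -1 for missing ]
              ...
            [ responses: sample_count ints (class indices) or floats (ordered) ]
            [ cv labels: sample_count ints, present only when have_labels ]

        In total that is ord_var_count*2 + cat_var_count + 1 + (have_labels ? 1:0)
        = ord_var_count + get_work_var_count() ints per sample, which is
        exactly the buf_size formula above.
    */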
    // transform the training data to convenient representation
    for( vi = 0; vi <= var_count; vi++ )
    {
        int ci;
        const uchar* mask = 0;
        int m_step = 0, step;
        const int* idata = 0;
        const float* fdata = 0;
        int num_valid = 0;

        if( vi < var_count ) // analyze i-th input variable
        {
            int vi0 = vidx ? vidx[vi] : vi;
            ci = get_var_type(vi);
            step = ds_step; m_step = ms_step;
            if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
                idata = _train_data->data.i + vi0*dv_step;
            else
                fdata = _train_data->data.fl + vi0*dv_step;
            if( _missing_mask )
                mask = _missing_mask->data.ptr + vi0*mv_step;
        }
        else // analyze _responses
        {
            ci = cat_var_count;
            step = CV_IS_MAT_CONT(_responses->type) ?
                1 : _responses->step / CV_ELEM_SIZE(_responses->type);
            if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
                idata = _responses->data.i;
            else
                fdata = _responses->data.fl;
        }

        if( (vi < var_count && ci >= 0) ||
            (vi == var_count && is_classifier) ) // process categorical variable or response
        {
            int c_count, prev_label;
            int* c_map, *dst = get_cat_var_data( data_root, vi );

            // copy data
            for( i = 0; i < sample_count; i++ )
            {
                int val = INT_MAX, si = sidx ? sidx[i] : i;
                if( !mask || !mask[si*m_step] )
                {
                    if( idata )
                        val = idata[si*step];
                    else
                    {
                        float t = fdata[si*step];
                        val = cvRound(t);
                        if( val != t )
                        {
                            sprintf( err, "%d-th value of %d-th (categorical) "
                                "variable is not an integer", i, vi );
                            CV_ERROR( CV_StsBadArg, err );
                        }
                    }

                    if( val == INT_MAX )
                    {
                        sprintf( err, "%d-th value of %d-th (categorical) "
                            "variable is too large", i, vi );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }
                dst[i] = val;
                int_ptr[i] = dst + i;
            }

            // sort all the values, including the missing measurements
            // that should all move to the end
            icvSortIntPtr( int_ptr, sample_count, 0 );
            //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );

            c_count = num_valid > 0;

            // count the categories
            for( i = 1; i < num_valid; i++ )
                c_count += *int_ptr[i] != *int_ptr[i-1];

            if( vi > 0 )
                max_c_count = MAX( max_c_count, c_count );
            cat_count->data.i[ci] = c_count;
            cat_ofs->data.i[ci] = total_c_count;

            // resize cat_map, if needed
            if( cat_map->cols < total_c_count + c_count )
            {
                tmp_map = cat_map;
                CV_CALL( cat_map = cvCreateMat( 1,
                    MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
                for( i = 0; i < total_c_count; i++ )
                    cat_map->data.i[i] = tmp_map->data.i[i];
                cvReleaseMat( &tmp_map );
            }

            c_map = cat_map->data.i + total_c_count;
            total_c_count += c_count;

            // compact the class indices and build the map
            prev_label = ~*int_ptr[0];
            c_count = -1;

            for( i = 0; i < num_valid; i++ )
            {
                int cur_label = *int_ptr[i];
                if( cur_label != prev_label )
                    c_map[++c_count] = prev_label = cur_label;
                *int_ptr[i] = c_count;
            }
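
            /*
                A worked example of the compaction above (values are
                illustrative): if the raw categorical column is
                {7, 7, 3, 9, <missing>}, the sorted valid values are
                {3, 7, 7, 9}, so c_count = 3 and c_map[] = {3, 7, 9}; the
                in-place relabeling rewrites the per-sample entries to
                {1, 1, 0, 2}, and the loop below tags the missing entry
                with -1.
            */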
            // replace labels for missing values with -1
            for( ; i < sample_count; i++ )
                *int_ptr[i] = -1;
        }
        else if( ci < 0 ) // process ordered variable
        {
            CvPair32s32f* dst = get_ord_var_data( data_root, vi );

            for( i = 0; i < sample_count; i++ )
            {
                float val = ord_nan;
                int si = sidx ? sidx[i] : i;
                if( !mask || !mask[si*m_step] )
                {
                    if( idata )
                        val = (float)idata[si*step];
                    else
                        val = fdata[si*step];

                    if( fabs(val) >= ord_nan )
                    {
                        sprintf( err, "%d-th value of %d-th (ordered) "
                            "variable (=%g) is too large", i, vi, val );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }
                dst[i].i = i;
                dst[i].val = val;
            }

            icvSortPairs( dst, sample_count, 0 );
        }
        else // special case: process ordered response,
             // it will be stored similarly to categorical vars (i.e. no pairs)
        {
            float* dst = get_ord_responses( data_root );

            for( i = 0; i < sample_count; i++ )
            {
                float val = ord_nan;
                int si = sidx ? sidx[i] : i;
                if( idata )
                    val = (float)idata[si*step];
                else
                    val = fdata[si*step];

                if( fabs(val) >= ord_nan )
                {
                    sprintf( err, "%d-th value of %d-th (ordered) "
                        "variable (=%g) is out of range", i, vi, val );
                    CV_ERROR( CV_StsBadArg, err );
                }
                dst[i] = val;
            }

            cat_count->data.i[cat_var_count] = 0;
            cat_ofs->data.i[cat_var_count] = total_c_count;
            num_valid = sample_count;
        }

        if( vi < var_count )
            data_root->set_num_valid(vi, num_valid);
    }

    if( cv_n )
    {
        int* dst = get_labels(data_root);
        CvRNG* r = &rng;

        for( i = vi = 0; i < sample_count; i++ )
        {
            dst[i] = vi++;
            vi &= vi < cv_n ? -1 : 0;
        }

        for( i = 0; i < sample_count; i++ )
        {
            int a = cvRandInt(r) % sample_count;
            int b = cvRandInt(r) % sample_count;
            CV_SWAP( dst[a], dst[b], vi );
        }
    }

    cat_map->cols = MAX( total_c_count, 1 );

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));

    have_priors = is_classifier && params.priors;
    if( is_classifier )
    {
        int m = get_num_classes();
        double sum = 0;
        CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
        for( i = 0; i < m; i++ )
        {
            double val = have_priors ? params.priors[i] : 1.;
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
            priors->data.db[i] = val;
            sum += val;
        }

        // normalize weights
        if( have_priors )
            cvScale( priors, priors, 1./sum );

        CV_CALL( priors_mult = cvCloneMat( priors ));
        CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
    }

    CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
    CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));

    __END__;

    if( data )
        delete data;

    cvFree( &int_ptr );
    cvReleaseMat( &sample_idx );
    cvReleaseMat( &var_type0 );
    cvReleaseMat( &tmp_map );
}
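
/*
    When _shared == true, a single CvDTreeTrainData can be reused by several
    trees, each trained on its own subsample (this is the mode the tree
    ensembles rely on). A hypothetical sketch; names are illustrative:

    CvDTreeTrainData* data = new CvDTreeTrainData( train_data, CV_ROW_SAMPLE,
        responses, 0, 0, 0, 0, params, true ); // _shared = true

    CvDTree tree1, tree2;
    tree1.train( data, subsample_idx1 ); // calls subsample_data() internally
    tree2.train( data, subsample_idx2 );

    // ... use the trees; release them *before* deleting the shared data,
    // since CvDTree::clear() still accesses it to free the tree nodes
    tree1.clear(); tree2.clear();
    delete data;
*/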
CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );

    __BEGIN__;

    if( !data_root )
        CV_ERROR( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));

    if( !isubsample_idx )
    {
        // make a copy of the root node
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
    {
        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        int* co, cur_ofs = 0;
        int vi, i, total = data_root->sample_count;
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;
        int work_var_count = get_work_var_count();
        root = new_node( 0, count, 1, 0 );

        CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
        cvZero( subsample_co );
        co = subsample_co->data.i;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < total; i++ )
        {
            if( co[i*2] )
            {
                co[i*2+1] = cur_ofs;
                cur_ofs += co[i*2];
            }
            else
                co[i*2+1] = -1;
        }

        for( vi = 0; vi < work_var_count; vi++ )
        {
            int ci = get_var_type(vi);

            if( ci >= 0 || vi >= var_count )
            {
                const int* src = get_cat_var_data( data_root, vi );
                int* dst = get_cat_var_data( root, vi );
                int num_valid = 0;

                for( i = 0; i < count; i++ )
                {
                    int val = src[sidx[i]];
                    dst[i] = val;
                    num_valid += val >= 0;
                }

                if( vi < var_count )
                    root->set_num_valid(vi, num_valid);
            }
            else
            {
                const CvPair32s32f* src = get_ord_var_data( data_root, vi );
                CvPair32s32f* dst = get_ord_var_data( root, vi );
                int j = 0, idx, count_i;
                int num_valid = data_root->get_num_valid(vi);

                for( i = 0; i < num_valid; i++ )
                {
                    idx = src[i].i;
                    count_i = co[idx*2];
                    if( count_i )
                    {
                        float val = src[i].val;
                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        {
                            dst[j].val = val;
                            dst[j].i = cur_ofs;
                        }
                    }
                }

                root->set_num_valid(vi, j);

                for( ; i < total; i++ )
                {
                    idx = src[i].i;
                    count_i = co[idx*2];
                    if( count_i )
                    {
                        float val = src[i].val;
                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        {
                            dst[j].val = val;
                            dst[j].i = cur_ofs;
                        }
                    }
                }
            }
        }
    }

    __END__;

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );

    return root;
}
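
/*
    get_vectors() below converts the internal sorted, index-based
    representation back into plain dense arrays. A hypothetical caller,
    for 'count' selected samples (buffer names are illustrative):

    float* values    = new float[count*var_count]; // row-major, one row per sample
    uchar* missing   = new uchar[count*var_count]; // set to 1 where a value is absent
    float* responses = new float[count];

    data->get_vectors( subsample_idx, values, missing, responses, false );
    // false = return the original class labels (via cat_map) rather than
    // the normalized class indices

    delete[] values; delete[] missing; delete[] responses;
*/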
void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
                                    float* values, uchar* missing,
                                    float* responses, bool get_class_idx )
{
    CvMat* subsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );

    __BEGIN__;

    int i, vi, total = sample_count, count = total, cur_ofs = 0;
    int* sidx = 0;
    int* co = 0;

    if( _subsample_idx )
    {
        CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
        sidx = subsample_idx->data.i;
        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        co = subsample_co->data.i;
        cvZero( subsample_co );
        count = subsample_idx->cols + subsample_idx->rows - 1;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < total; i++ )
        {
            int count_i = co[i*2];
            if( count_i )
            {
                co[i*2+1] = cur_ofs*var_count;
                cur_ofs += count_i;
            }
        }
    }

    if( missing )
        memset( missing, 1, count*var_count );

    for( vi = 0; vi < var_count; vi++ )
    {
        int ci = get_var_type(vi);
        if( ci >= 0 ) // categorical
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            const int* src = get_cat_var_data(data_root, vi);

            for( i = 0; i < count; i++, dst += var_count )
            {
                int idx = sidx ? sidx[i] : i;
                int val = src[idx];
                *dst = (float)val;
                if( m )
                {
                    *m = val < 0;
                    m += var_count;
                }
            }
        }
        else // ordered
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            const CvPair32s32f* src = get_ord_var_data(data_root, vi);
            int count1 = data_root->get_num_valid(vi);

            for( i = 0; i < count1; i++ )
            {
                int idx = src[i].i;
                int count_i = 1;
                if( co )
                {
                    count_i = co[idx*2];
                    cur_ofs = co[idx*2+1];
                }
                else
                    cur_ofs = idx*var_count;
                if( count_i )
                {
                    float val = src[i].val;
                    for( ; count_i > 0; count_i--, cur_ofs += var_count )
                    {
                        dst[cur_ofs] = val;
                        if( m )
                            m[cur_ofs] = 0;
                    }
                }
            }
        }
    }

    // copy responses
    if( responses )
    {
        if( is_classifier )
        {
            const int* src = get_class_labels(data_root);
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ? sidx[i] : i;
                int val = get_class_idx ? src[idx] :
                    cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
                responses[i] = (float)val;
            }
        }
        else
        {
            const float* src = get_ord_responses(data_root);
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ? sidx[i] : i;
                responses[i] = src[idx];
            }
        }
    }

    __END__;

    cvReleaseMat( &subsample_idx );
    cvReleaseMat( &subsample_co );
}


CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
                                         int storage_idx, int offset )
{
    CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );

    node->sample_count = count;
    node->depth = parent ? parent->depth + 1 : 0;
    node->parent = parent;
    node->left = node->right = 0;
    node->split = 0;
    node->value = 0;
    node->class_idx = 0;
    node->maxlr = 0.;

    node->buf_idx = storage_idx;
    node->offset = offset;
    if( nv_heap )
        node->num_valid = (int*)cvSetNew( nv_heap );
    else
        node->num_valid = 0;
    node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
    node->complexity = 0;

    if( params.cv_folds > 0 && cv_heap )
    {
        int cv_n = params.cv_folds;
        node->Tn = INT_MAX;
        node->cv_Tn = (int*)cvSetNew( cv_heap );
        node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
        node->cv_node_error = node->cv_node_risk + cv_n;
    }
    else
    {
        node->Tn = 0;
        node->cv_Tn = 0;
        node->cv_node_risk = 0;
        node->cv_node_error = 0;
    }

    return node;
}
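
/*
    Note on the cross-validation fields set above: a single cv_heap element
    stores cv_folds ints (cv_Tn) immediately followed by 2*cv_folds doubles
    (cv_node_risk, then cv_node_error). This matches
    cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) )
    in set_data(), so one allocation serves all three per-node arrays, and
    they are released together with temp_storage rather than per node
    (see free_node_data()).
*/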
CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    split->var_idx = vi;
    split->ord.c = cmp_val;
    split->ord.split_point = split_point;
    split->inversed = inversed;
    split->quality = quality;
    split->next = 0;

    return split;
}


CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    int i, n = (max_c_count + 31)/32;

    split->var_idx = vi;
    split->inversed = 0;
    split->quality = quality;
    for( i = 0; i < n; i++ )
        split->subset[i] = 0;
    split->next = 0;

    return split;
}


void CvDTreeTrainData::free_node( CvDTreeNode* node )
{
    CvDTreeSplit* split = node->split;
    free_node_data( node );
    while( split )
    {
        CvDTreeSplit* next = split->next;
        cvSetRemoveByPtr( split_heap, split );
        split = next;
    }
    node->split = 0;
    cvSetRemoveByPtr( node_heap, node );
}


void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
{
    if( node->num_valid )
    {
        cvSetRemoveByPtr( nv_heap, node->num_valid );
        node->num_valid = 0;
    }
    // do not free cv_* fields, as all the cross-validation related data is released at once.
}


void CvDTreeTrainData::free_train_data()
{
    cvReleaseMat( &counts );
    cvReleaseMat( &buf );
    cvReleaseMat( &direction );
    cvReleaseMat( &split_buf );
    cvReleaseMemStorage( &temp_storage );
    cv_heap = nv_heap = 0;
}


void CvDTreeTrainData::clear()
{
    free_train_data();

    cvReleaseMemStorage( &tree_storage );

    cvReleaseMat( &var_idx );
    cvReleaseMat( &var_type );
    cvReleaseMat( &cat_count );
    cvReleaseMat( &cat_ofs );
    cvReleaseMat( &cat_map );
    cvReleaseMat( &priors );
    cvReleaseMat( &priors_mult );

    node_heap = split_heap = 0;

    sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
    have_labels = have_priors = is_classifier = false;

    buf_count = buf_size = 0;
    shared = false;

    data_root = 0;

    rng = cvRNG(-1);
}


int CvDTreeTrainData::get_num_classes() const
{
    return is_classifier ? cat_count->data.i[cat_var_count] : 0;
}


int CvDTreeTrainData::get_var_type(int vi) const
{
    return var_type->data.i[vi];
}


int CvDTreeTrainData::get_work_var_count() const
{
    return var_count + 1 + (have_labels ? 1 : 0);
}

CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
{
    int oi = ~get_var_type(vi);
    assert( 0 <= oi && oi < ord_var_count );
    return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
        n->offset + oi*n->sample_count*2);
}


int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
{
    return get_cat_var_data( n, var_count );
}


float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
{
    return (float*)get_cat_var_data( n, var_count );
}


int* CvDTreeTrainData::get_labels( CvDTreeNode* n )
{
    return have_labels ? get_cat_var_data( n, var_count + 1 ) : 0;
}


int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
{
    int ci = get_var_type(vi);
    assert( 0 <= ci && ci <= cat_var_count + 1 );
    return buf->data.i + n->buf_idx*buf->cols + n->offset +
        (ord_var_count*2 + ci)*n->sample_count;
}


int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
{
    int idx = n->buf_idx + 1;
    if( idx >= buf_count )
        idx = shared ? 1 : 0;
    return idx;
}
void CvDTreeTrainData::write_params( CvFileStorage* fs )
{
    CV_FUNCNAME( "CvDTreeTrainData::write_params" );

    __BEGIN__;

    int vi, vcount = var_count;

    cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
    cvWriteInt( fs, "var_all", var_all );
    cvWriteInt( fs, "var_count", var_count );
    cvWriteInt( fs, "ord_var_count", ord_var_count );
    cvWriteInt( fs, "cat_var_count", cat_var_count );

    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );

    if( is_classifier )
    {
        cvWriteInt( fs, "max_categories", params.max_categories );
    }
    else
    {
        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
    }

    cvWriteInt( fs, "max_depth", params.max_depth );
    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );

    if( params.cv_folds > 1 )
    {
        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
    }

    if( priors )
        cvWrite( fs, "priors", priors );

    cvEndWriteStruct( fs );

    if( var_idx )
        cvWrite( fs, "var_idx", var_idx );

    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );

    for( vi = 0; vi < vcount; vi++ )
        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );

    cvEndWriteStruct( fs );

    if( cat_count && (cat_var_count > 0 || is_classifier) )
    {
        CV_ASSERT( cat_count != 0 );
        cvWrite( fs, "cat_count", cat_count );
        cvWrite( fs, "cat_map", cat_map );
    }

    __END__;
}
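
/*
    The resulting file-storage entry looks roughly like this in YAML form
    (field values are illustrative):

    is_classifier: 1
    var_all: 4
    var_count: 4
    ord_var_count: 2
    cat_var_count: 2
    training_params:
       use_surrogates: 1
       max_categories: 10
       max_depth: 8
       min_sample_count: 10
       cross_validation_folds: 5
       use_1se_rule: 1
       truncate_pruned_tree: 1
       priors: !!opencv-matrix { ... }
    var_idx: !!opencv-matrix { ... }
    var_type: [ 0, 1, 0, 1 ]
    cat_count: !!opencv-matrix { ... }
    cat_map: !!opencv-matrix { ... }

    read_params() below parses exactly this layout back.
*/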
void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
{
    CV_FUNCNAME( "CvDTreeTrainData::read_params" );

    __BEGIN__;

    CvFileNode *tparams_node, *vartype_node;
    CvSeqReader reader;
    int vi, max_split_size, tree_block_size;
    int cat_var_count0, ord_var_count0;

    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
    var_all = cvReadIntByName( fs, node, "var_all" );
    var_count = cvReadIntByName( fs, node, "var_count", var_all );
    cat_var_count0 = cvReadIntByName( fs, node, "cat_var_count" );
    ord_var_count0 = cvReadIntByName( fs, node, "ord_var_count" );

    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );

    if( tparams_node ) // training parameters are not necessary
    {
        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;

        if( is_classifier )
        {
            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
        }
        else
        {
            params.regression_accuracy =
                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
        }

        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );

        if( params.cv_folds > 1 )
        {
            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
            params.truncate_pruned_tree =
                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
        }

        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
        if( priors )
        {
            if( !CV_IS_MAT(priors) )
                CV_ERROR( CV_StsParseError, "priors must be stored as a matrix" );
            priors_mult = cvCloneMat( priors );
        }
    }

    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
    if( var_idx )
    {
        if( !CV_IS_MAT(var_idx) ||
            (var_idx->cols != 1 && var_idx->rows != 1) ||
            var_idx->cols + var_idx->rows - 1 != var_count ||
            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
                "var_idx (if it exists) must be a valid 1d integer vector containing <var_count> elements" );

        for( vi = 0; vi < var_count; vi++ )
            if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
    }

    ////// read var type
    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;
    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );

    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
    else
    {
        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
            vartype_node->data.seq->total != var_count )
            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );

        cvStartReadSeq( vartype_node->data.seq, &reader );

        for( vi = 0; vi < var_count; vi++ )
        {
            CvFileNode* n = (CvFileNode*)reader.ptr;
            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
            var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }
    var_type->data.i[var_count] = cat_var_count;

    ord_var_count = ~ord_var_count;
    // compare the recomputed counts with the ones stored in the file
    if( cat_var_count != cat_var_count0 || ord_var_count != ord_var_count0 )
        CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
    //////

    if( cat_var_count > 0 || is_classifier )
    {
        int ccount, total_c_count = 0;
        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));

        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
            (cat_count->cols != 1 && cat_count->rows != 1) ||
            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
            cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
            (cat_map->cols != 1 && cat_map->rows != 1) ||
            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
            "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );

        ccount = cat_var_count + is_classifier;

        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
        cat_ofs->data.i[0] = 0;
        max_c_count = 1;

        for( vi = 0; vi < ccount; vi++ )
        {
            int val = cat_count->data.i[vi];
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
            max_c_count = MAX( max_c_count, val );
            cat_ofs->data.i[vi+1] = total_c_count += val;
        }

        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
            CV_ERROR( CV_StsBadSize,
            "cat_map vector length is not equal to the total number of categories in all categorical vars" );
    }

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));

    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
            sizeof(CvDTreeNode), tree_storage ));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
            max_split_size, tree_storage ));

    __END__;
}
/////////////////////// Decision Tree /////////////////////////

CvDTree::CvDTree()
{
    data = 0;
    var_importance = 0;
    default_model_name = "my_tree";

    clear();
}


void CvDTree::clear()
{
    cvReleaseMat( &var_importance );
    if( data )
    {
        if( !data->shared )
            delete data;
        else
            free_tree();
        data = 0;
    }
    root = 0;
    pruned_tree_idx = -1;
}


CvDTree::~CvDTree()
{
    clear();
}


const CvDTreeNode* CvDTree::get_root() const
{
    return root;
}


int CvDTree::get_pruned_tree_idx() const
{
    return pruned_tree_idx;
}


CvDTreeTrainData* CvDTree::get_data()
{
    return data;
}


bool CvDTree::train( const CvMat* _train_data, int _tflag,
                     const CvMat* _responses, const CvMat* _var_idx,
                     const CvMat* _sample_idx, const CvMat* _var_type,
                     const CvMat* _missing_mask, CvDTreeParams _params )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = new CvDTreeTrainData( _train_data, _tflag, _responses,
                                 _var_idx, _sample_idx, _var_type,
                                 _missing_mask, _params, false );
    CV_CALL( result = do_train(0));

    __END__;

    return result;
}


bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = _data;
    data->shared = true;
    CV_CALL( result = do_train(_subsample_idx));

    __END__;

    return result;
}


bool CvDTree::do_train( const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::do_train" );

    __BEGIN__;

    root = data->subsample_data( _subsample_idx );

    CV_CALL( try_split_node(root));

    if( data->params.cv_folds > 0 )
        CV_CALL( prune_cv());

    if( !data->shared )
        data->free_train_data();

    result = true;

    __END__;

    return result;
}
void CvDTree::try_split_node( CvDTreeNode* node )
{
    CvDTreeSplit* best_split = 0;
    int i, n = node->sample_count, vi;
    bool can_split = true;
    double quality_scale;

    calc_node_value( node );

    if( node->sample_count <= data->params.min_sample_count ||
        node->depth >= data->params.max_depth )
        can_split = false;

    if( can_split && data->is_classifier )
    {
        // check if we have a "pure" node,
        // we assume that cls_count is filled by calc_node_value()
        int* cls_count = data->counts->data.i;
        int nz = 0, m = data->get_num_classes();
        for( i = 0; i < m; i++ )
            nz += cls_count[i] != 0;
        if( nz == 1 ) // there is only one class
            can_split = false;
    }
    else if( can_split )
    {
        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
            can_split = false;
    }

    if( can_split )
    {
        best_split = find_best_split(node);
        // TODO: check the split quality ...
        node->split = best_split;
    }

    if( !can_split || !best_split )
    {
        data->free_node_data(node);
        return;
    }

    quality_scale = calc_node_dir( node );

    if( data->params.use_surrogates )
    {
        // find all the surrogate splits
        // and sort them by their similarity to the primary one
        for( vi = 0; vi < data->var_count; vi++ )
        {
            CvDTreeSplit* split;
            int ci = data->get_var_type(vi);

            if( vi == best_split->var_idx )
                continue;

            if( ci >= 0 )
                split = find_surrogate_split_cat( node, vi );
            else
                split = find_surrogate_split_ord( node, vi );

            if( split )
            {
                // insert the split
                CvDTreeSplit* prev_split = node->split;
                split->quality = (float)(split->quality*quality_scale);

                while( prev_split->next &&
                       prev_split->next->quality > split->quality )
                    prev_split = prev_split->next;
                split->next = prev_split->next;
                prev_split->next = split;
            }
        }
    }

    split_node_data( node );
    try_split_node( node->left );
    try_split_node( node->right );
}
// calculate direction (left(-1),right(1),missing(0))
// for each sample using the best split
// the function returns the scale coefficient for surrogate split quality factors.
// the scale is applied to normalize surrogate split quality relatively to the
// best (primary) split quality. That is, if a surrogate split is absolutely
// identical to the primary split, its quality will be set to the maximum value =
// quality of the primary split; otherwise, it will be lower.
// besides, the function computes node->maxlr,
// the minimum quality (w/o considering the above mentioned scale)
// that a surrogate split must exceed. Surrogate splits with quality
// not greater than node->maxlr are discarded.
double CvDTree::calc_node_dir( CvDTreeNode* node )
{
    char* dir = (char*)data->direction->data.ptr;
    int i, n = node->sample_count, vi = node->split->var_idx;
    double L, R;

    assert( !node->split->inversed );

    if( data->get_var_type(vi) >= 0 ) // split on categorical var
    {
        const int* labels = data->get_cat_var_data(node,vi);
        const int* subset = node->split->subset;

        if( !data->have_priors )
        {
            int sum = 0, sum_abs = 0;

            for( i = 0; i < n; i++ )
            {
                int idx = labels[i];
                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
                sum += d; sum_abs += d & 1;
                dir[i] = (char)d;
            }

            R = (sum_abs + sum) >> 1;
            L = (sum_abs - sum) >> 1;
        }
        else
        {
            const int* responses = data->get_class_labels(node);
            const double* priors = data->priors_mult->data.db;
            double sum = 0, sum_abs = 0;

            for( i = 0; i < n; i++ )
            {
                int idx = labels[i];
                double w = priors[responses[i]];
                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
                sum += d*w; sum_abs += (d & 1)*w;
                dir[i] = (char)d;
            }

            R = (sum_abs + sum) * 0.5;
            L = (sum_abs - sum) * 0.5;
        }
    }
    else // split on ordered var
    {
        const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
        int split_point = node->split->ord.split_point;
        int n1 = node->get_num_valid(vi);

        assert( 0 <= split_point && split_point < n1-1 );

        if( !data->have_priors )
        {
            for( i = 0; i <= split_point; i++ )
                dir[sorted[i].i] = (char)-1;
            for( ; i < n1; i++ )
                dir[sorted[i].i] = (char)1;
            for( ; i < n; i++ )
                dir[sorted[i].i] = (char)0;

            // samples 0..split_point go to the left, the remaining
            // valid samples go to the right (consistent with the
            // weighted branch below)
            L = split_point + 1;
            R = n1 - split_point - 1;
        }
        else
        {
            const int* responses = data->get_class_labels(node);
            const double* priors = data->priors_mult->data.db;
            L = R = 0;

            for( i = 0; i <= split_point; i++ )
            {
                int idx = sorted[i].i;
                double w = priors[responses[idx]];
                dir[idx] = (char)-1;
                L += w;
            }

            for( ; i < n1; i++ )
            {
                int idx = sorted[i].i;
                double w = priors[responses[idx]];
                dir[idx] = (char)1;
                R += w;
            }

            for( ; i < n; i++ )
                dir[sorted[i].i] = (char)0;
        }
    }

    node->maxlr = MAX( L, R );
    return node->split->quality/(L + R);
}


CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
{
    int vi;
    CvDTreeSplit *best_split = 0, *split = 0, *t;

    for( vi = 0; vi < data->var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        if( node->get_num_valid(vi) <= 1 )
            continue;

        if( data->is_classifier )
        {
            if( ci >= 0 )
                split = find_split_cat_class( node, vi );
            else
                split = find_split_ord_class( node, vi );
        }
        else
        {
            if( ci >= 0 )
                split = find_split_cat_reg( node, vi );
            else
                split = find_split_ord_reg( node, vi );
        }

        if( split )
        {
            if( !best_split || best_split->quality < split->quality )
                CV_SWAP( best_split, split, t );
            if( split )
                cvSetRemoveByPtr( data->split_heap, split );
        }
    }

    return best_split;
}
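
/*
    On the split quality maximized by the find_split_* functions below: for
    a candidate partition with (weighted) side sizes L, R and per-class
    counts l_k, r_k, they maximize

        val = lsum2/L + rsum2/R,   lsum2 = sum_k l_k^2,  rsum2 = sum_k r_k^2,

    evaluated as (lsum2*R + rsum2*L)/(L*R) to spare a division. Since the
    weighted Gini impurity of the partition is
    L*(1 - lsum2/L^2) + R*(1 - rsum2/R^2) = (L + R) - val, and L + R is
    fixed for the node, maximizing val minimizes the Gini impurity. The O(1)
    update when one sample of class k moves to the left follows from
    (l_k + 1)^2 = l_k^2 + 2*l_k + 1, i.e. lsum2 += 2*l_k + 1 and
    rsum2 -= 2*r_k - 1, which is exactly what the inner loops do.
*/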
CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
{
    const float epsilon = FLT_EPSILON*2;
    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    const int* responses = data->get_class_labels(node);
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);
    int m = data->get_num_classes();
    const int* rc0 = data->counts->data.i;
    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
    int i, best_i = -1;
    double lsum2 = 0, rsum2 = 0, best_val = 0;
    const double* priors = data->have_priors ? data->priors_mult->data.db : 0;

    // init arrays of class instance counters on both sides of the split
    for( i = 0; i < m; i++ )
    {
        lc[i] = 0;
        rc[i] = rc0[i];
    }

    // compensate for missing values
    for( i = n1; i < n; i++ )
        rc[responses[sorted[i].i]]--;

    if( !priors )
    {
        int L = 0, R = n1;

        for( i = 0; i < m; i++ )
            rsum2 += (double)rc[i]*rc[i];

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted[i].i];
            int lv, rv;
            L++; R--;
            lv = lc[idx]; rv = rc[idx];
            lsum2 += lv*2 + 1;
            rsum2 -= rv*2 - 1;
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( sorted[i].val + epsilon < sorted[i+1].val )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }
    else
    {
        double L = 0, R = 0;
        for( i = 0; i < m; i++ )
        {
            double wv = rc[i]*priors[i];
            R += wv;
            rsum2 += wv*wv;
        }

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted[i].i];
            int lv, rv;
            double p = priors[idx], p2 = p*p;
            L += p; R -= p;
            lv = lc[idx]; rv = rc[idx];
            lsum2 += p2*(lv*2 + 1);
            rsum2 -= p2*(rv*2 - 1);
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( sorted[i].val + epsilon < sorted[i+1].val )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }

    return best_i >= 0 ? data->new_split_ord( vi,
        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
        0, (float)best_val ) : 0;
}
void CvDTree::cluster_categories( const int* vectors, int n, int m,
                                  int* csums, int k, int* labels )
{
    // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
    int iters = 0, max_iters = 100;
    int i, j, idx;
    double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
    // n vector weights followed by k cluster weights
    double *v_weights = buf, *c_weights = buf + n;
    bool modified = true;
    CvRNG* r = &data->rng;

    // assign labels randomly
    for( i = idx = 0; i < n; i++ )
    {
        int sum = 0;
        const int* v = vectors + i*m;
        labels[i] = idx++;
        idx &= idx < k ? -1 : 0;

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    for( i = 0; i < n; i++ )
    {
        int i1 = cvRandInt(r) % n;
        int i2 = cvRandInt(r) % n;
        CV_SWAP( labels[i1], labels[i2], j );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const int* v = vectors + i*m;
            int* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const int* s = csums + i*m;
            int sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const int* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const int* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}
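
/*
    cluster_categories() above is essentially a tiny k-means: every category
    is represented by its histogram of class counts (a row of m ints), rows
    and cluster sums are compared after normalization (v_weights/c_weights
    hold the reciprocals of the row/cluster totals), and each category is
    reassigned to the nearest cluster until nothing changes or max_iters is
    reached. find_split_cat_class() below uses it to merge a variable's
    categories down to at most params.max_categories pseudo-categories,
    since the subsequent subset search is exponential (2^mi) in the number
    of categories.
*/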
CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
{
    CvDTreeSplit* split;
    const int* labels = data->get_cat_var_data(node, vi);
    const int* responses = data->get_class_labels(node);
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int m = data->get_num_classes();
    int _mi = data->cat_count->data.i[ci], mi = _mi;
    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
    int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;
    double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
    int* cluster_labels = 0;
    int** int_ptr = 0;
    int i, j, k, idx;
    double L = 0, R = 0;
    double best_val = 0;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
    const double* priors = data->priors_mult->data.db;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        j = labels[i];
        k = responses[i];
        cjk[j*m + k]++;
    }

    if( m > 2 )
    {
        if( mi > data->params.max_categories )
        {
            mi = MIN(data->params.max_categories, n);
            // the clustered counters need their own buffer; appending them
            // to _cjk would overrun its allocation
            cjk = (int*)cvStackAlloc( m*mi*sizeof(cjk[0]) );
            cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));
            cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
        for( j = 0; j < mi; j++ )
            int_ptr[j] = cjk + j*2 + 1;
        icvSortIntPtr( int_ptr, mi, 0 );
        subset_i = 0;
        subset_n = mi;
    }

    for( k = 0; k < m; k++ )
    {
        int sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        rc[k] = sum;
        lc[k] = 0;
    }

    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k]*priors[k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double weight;
        int* crow;
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(int_ptr[subset_i] - cjk)/2;
        else
        {
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        crow = cjk + idx*m;
        weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            for( k = 0; k < m; k++ )
            {
                int t = crow[k];
                int lval = lc[k] + t;
                int rval = rc[k] - t;
                double p = priors[k], p2 = p*p;
                lsum2 += p2*lval*lval;
                rsum2 += p2*rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            for( k = 0; k < m; k++ )
            {
                int t = crow[k];
                int lval = lc[k] - t;
                int rval = rc[k] + t;
                double p = priors[k], p2 = p*p;
                lsum2 += p2*lval*lval;
                rsum2 += p2*rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/((double)L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    if( best_subset < 0 )
        return 0;

    split = data->new_split_cat( vi, (float)best_val );

    if( m == 2 )
    {
        for( i = 0; i <= best_subset; i++ )
        {
            idx = (int)(int_ptr[i] - cjk) >> 1;
            split->subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    else
    {
        for( i = 0; i < _mi; i++ )
        {
            idx = cluster_labels ? cluster_labels[i] : i;
            if( best_subset & (1 << idx) )
                split->subset[i >> 5] |= 1 << (i & 31);
        }
    }

    return split;
}
CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
{
    const float epsilon = FLT_EPSILON*2;
    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    const float* responses = data->get_ord_responses(node);
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);
    int i, best_i = -1;
    double best_val = 0, lsum = 0, rsum = node->value*n;
    int L = 0, R = n1;

    // compensate for missing values
    for( i = n1; i < n; i++ )
        rsum -= responses[sorted[i].i];

    // find the optimal split
    for( i = 0; i < n1 - 1; i++ )
    {
        float t = responses[sorted[i].i];
        L++; R--;
        lsum += t;
        rsum -= t;

        if( sorted[i].val + epsilon < sorted[i+1].val )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    return best_i >= 0 ? data->new_split_ord( vi,
        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
        0, (float)best_val ) : 0;
}


CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
{
    CvDTreeSplit* split;
    const int* labels = data->get_cat_var_data(node, vi);
    const float* responses = data->get_ord_responses(node);
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int mi = data->cat_count->data.i[ci];
    double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
    int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
    double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) );
    int i, L = 0, R = 0;
    double best_val = 0, lsum = 0, rsum = 0;
    int best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = labels[i];
        double s = sum[idx] + responses[i];
        int nc = counts[idx] + 1;
        sum[idx] = s;
        counts[idx] = nc;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] /= MAX(counts[i],1);
        sum_ptr[i] = sum + i;
    }

    icvSortDblPtr( sum_ptr, mi, 0 );

    // revert back to unnormalized sums
    // (there should be very little loss of accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        int ni = counts[idx];

        if( ni )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L && R )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    if( best_subset < 0 )
        return 0;

    split = data->new_split_cat( vi, (float)best_val );
    for( i = 0; i <= best_subset; i++ )
    {
        int idx = (int)(sum_ptr[i] - sum);
        split->subset[idx >> 5] |= 1 << (idx & 31);
    }

    return split;
}
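
/*
    The two regression splitters above rely on the same algebra as the
    classification ones: with response sums lsum, rsum over the L and R
    samples of a candidate split,

        val = lsum^2/L + rsum^2/R   ( = (lsum^2*R + rsum^2*L)/(L*R) )

    differs from the negative within-split sum of squared errors only by
    the node's (constant) sum of squared responses, so maximizing val
    minimizes the residual error after the split.
*/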

CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
{
    const float epsilon = FLT_EPSILON*2;
    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    const char* dir = (char*)data->direction->data.ptr;
    int n1 = node->get_num_valid(vi);
    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    int i, best_i = -1, best_inversed = 0;
    double best_val;

    if( !data->have_priors )
    {
        int LL = 0, RL = 0, LR, RR;
        int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
        int sum = 0, sum_abs = 0;

        for( i = 0; i < n1; i++ )
        {
            int d = dir[sorted[i].i];
            sum += d; sum_abs += d & 1;
        }

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum) >> 1;
        LR = (sum_abs - sum) >> 1;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by the primary split, and RR of them to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        for( i = 0; i < n1 - 1; i++ )
        {
            int d = dir[sorted[i].i];

            if( d < 0 )
            {
                LL++; LR--;
                if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
                {
                    // update _best_val (not best_val); best_val is assigned
                    // from _best_val after the loop and would otherwise
                    // overwrite the result.
                    _best_val = LL + RR;
                    best_i = i; best_inversed = 0;
                }
            }
            else if( d > 0 )
            {
                RL++; RR--;
                if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
                {
                    _best_val = RL + LR;
                    best_i = i; best_inversed = 1;
                }
            }
        }
        best_val = _best_val;
    }
    else
    {
        double LL = 0, RL = 0, LR, RR;
        double worst_val = node->maxlr;
        double sum = 0, sum_abs = 0;
        const double* priors = data->priors_mult->data.db;
        const int* responses = data->get_class_labels(node);
        best_val = worst_val;

        for( i = 0; i < n1; i++ )
        {
            int idx = sorted[i].i;
            double w = priors[responses[idx]];
            int d = dir[idx];
            sum += d*w; sum_abs += (d & 1)*w;
        }

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum)*0.5;
        LR = (sum_abs - sum)*0.5;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by the primary split, and RR of them to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = sorted[i].i;
            double w = priors[responses[idx]];
            int d = dir[idx];

            if( d < 0 )
            {
                LL += w; LR -= w;
                if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
                {
                    best_val = LL + RR;
                    best_i = i; best_inversed = 0;
                }
            }
            else if( d > 0 )
            {
                RL += w; RR -= w;
                if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
                {
                    best_val = RL + LR;
                    best_i = i; best_inversed = 1;
                }
            }
        }
    }

    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
        best_inversed, (float)best_val ) : 0;
}
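
/*
    Note on the acceptance test "best_val > node->maxlr" above: the quality of
    a surrogate is the total weight of samples it routes the same way as the
    primary split, either directly (LL + RR) or with the inversed sense
    (RL + LR). The trivial surrogate "send everybody with the majority"
    already agrees on max(L, R), which (as the code reads) is what
    node->maxlr stores, so only surrogates that beat this baseline are worth
    keeping. For example, if the primary split sends weight 60 to the left
    and 40 to the right (maxlr = 60), a candidate with LL + RR = 70 is
    accepted, while one with agreement 55 is discarded.
*/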

CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
{
    const int* labels = data->get_cat_var_data(node, vi);
    const char* dir = (char*)data->direction->data.ptr;
    int n = node->sample_count;
    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
    int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
    double best_val = 0;
    double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
    double* rc = lc + mi + 1;

    for( i = -1; i < mi; i++ )
        lc[i] = rc[i] = 0;

    // 1. for each category calculate the weight of samples
    //    sent to the left (lc) and to the right (rc) by the primary split
    if( !data->have_priors )
    {
        int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;
        int* _rc = _lc + mi + 1;

        for( i = -1; i < mi; i++ )
            _lc[i] = _rc[i] = 0;

        for( i = 0; i < n; i++ )
        {
            int idx = labels[i];
            int d = dir[i];
            int sum = _lc[idx] + d;
            int sum_abs = _rc[idx] + (d & 1);
            _lc[idx] = sum; _rc[idx] = sum_abs;
        }

        for( i = 0; i < mi; i++ )
        {
            int sum = _lc[i];
            int sum_abs = _rc[i];
            lc[i] = (sum_abs - sum) >> 1;
            rc[i] = (sum_abs + sum) >> 1;
        }
    }
    else
    {
        const double* priors = data->priors_mult->data.db;
        const int* responses = data->get_class_labels(node);

        for( i = 0; i < n; i++ )
        {
            int idx = labels[i];
            double w = priors[responses[i]];
            int d = dir[i];
            double sum = lc[idx] + d*w;
            double sum_abs = rc[idx] + (d & 1)*w;
            lc[idx] = sum; rc[idx] = sum_abs;
        }

        for( i = 0; i < mi; i++ )
        {
            double sum = lc[i];
            double sum_abs = rc[i];
            lc[i] = (sum_abs - sum) * 0.5;
            rc[i] = (sum_abs + sum) * 0.5;
        }
    }

    // 2. now form the split.
    //    in each category send all the samples in the same direction as the majority
    for( i = 0; i < mi; i++ )
    {
        double lval = lc[i], rval = rc[i];
        if( lval > rval )
        {
            split->subset[i >> 5] |= 1 << (i & 31);
            best_val += lval;
            l_win++;
        }
        else
            best_val += rval;
    }

    split->quality = (float)best_val;
    if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
        cvSetRemoveByPtr( data->split_heap, split ), split = 0;

    return split;
}
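
/*
    Sketch of how the subset[] bitmask built above is consumed elsewhere
    (CV_DTREE_CAT_DIR is defined in _ml.h; this is just a usage illustration,
    with the direction convention inferred from complete_node_dir() and
    predict()): a category whose bit is set is routed to the left child, all
    the others to the right.

        int subset[2] = { 0, 0 };
        subset[5 >> 5] |= 1 << (5 & 31);        // category 5 goes left
        int d5 = CV_DTREE_CAT_DIR(5, subset);   // -1, i.e. left
        int d7 = CV_DTREE_CAT_DIR(7, subset);   //  1, i.e. right
*/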

void CvDTree::calc_node_value( CvDTreeNode* node )
{
    int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
    const int* cv_labels = data->get_labels(node);

    if( data->is_classifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        int* cls_count = data->counts->data.i;
        const int* responses = data->get_class_labels(node);
        int m = data->get_num_classes();
        int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));
        double max_val = -1, total_weight = 0;
        int max_k = -1;
        double* priors = data->priors_mult->data.db;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
                cls_count[responses[i]]++;
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i]; k = responses[i];
                cv_cls_count[j*m + k]++;
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        if( data->have_priors && node->parent == 0 )
        {
            // compute priors_mult from priors, taking the sample ratio into account.
            double sum = 0;
            for( k = 0; k < m; k++ )
            {
                int n_k = cls_count[k];
                priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
                sum += priors[k];
            }
            sum = 1./sum;
            for( k = 0; k < m; k++ )
                priors[k] *= sum;
        }

        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k]*priors[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = data->cat_map->data.i[
            data->cat_ofs->data.i[data->cat_var_count] + max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double w = priors[k];
                double val_k = cv_cls_count[j*m + k]*w;
                double val = cls_count[k]*w - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            node->cv_Tn[j] = INT_MAX;
            node->cv_node_risk[j] = sum - max_val;
            node->cv_node_error[j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is the i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.

        double sum = 0, sum2 = 0;
        const float* values = data->get_ord_responses(node);
        double *cv_sum = 0, *cv_sum2 = 0;
        int* cv_count = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                double t = values[i];
                sum += t;
                sum2 += t*t;
            }
        }
        else
        {
            cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
            cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
            cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i];
                double t = values[i];
                double s = cv_sum[j] + t;
                double s2 = cv_sum2[j] + t*t;
                int nc = cv_count[j] + 1;
                cv_sum[j] = s;
                cv_sum2[j] = s2;
                cv_count[j] = nc;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
            }
        }

        node->node_risk = sum2 - (sum/n)*sum;
        node->value = sum/n;

        for( j = 0; j < cv_n; j++ )
        {
            double s = cv_sum[j], si = sum - s;
            double s2 = cv_sum2[j], s2i = sum2 - s2;
            int c = cv_count[j], ci = n - c;
            double r = si/MAX(ci,1);
            node->cv_node_risk[j] = s2i - r*r*ci;
            node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
            node->cv_Tn[j] = INT_MAX;
        }
    }
}
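
/*
    A worked example of the priors_mult adjustment above: with user priors
    (0.5, 0.5) and root class counts (90, 10), the per-sample weights become
    proportional to 0.5/90 and 0.5/10, i.e. 0.1 and 0.9 after normalization.
    Then cls_count[k]*priors[k] gives 90*0.1 == 10*0.9 == 9, so both classes
    carry equal total weight and the minority class is not drowned out when
    node values and risks are computed.
*/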

void CvDTree::complete_node_dir( CvDTreeNode* node )
{
    int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
    int nz = n - node->get_num_valid(node->split->var_idx);
    char* dir = (char*)data->direction->data.ptr;

    // try to complete direction using surrogate splits
    if( nz && data->params.use_surrogates )
    {
        CvDTreeSplit* split = node->split->next;
        for( ; split != 0 && nz; split = split->next )
        {
            int inversed_mask = split->inversed ? -1 : 0;
            vi = split->var_idx;

            if( data->get_var_type(vi) >= 0 ) // split on categorical var
            {
                const int* labels = data->get_cat_var_data(node, vi);
                const int* subset = split->subset;

                for( i = 0; i < n; i++ )
                {
                    int idx;
                    if( !dir[i] && (idx = labels[i]) >= 0 )
                    {
                        int d = CV_DTREE_CAT_DIR(idx,subset);
                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
                        // stop scanning only when every direction is filled in
                        if( --nz == 0 )
                            break;
                    }
                }
            }
            else // split on ordered var
            {
                const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
                int split_point = split->ord.split_point;
                int n1 = node->get_num_valid(vi);

                assert( 0 <= split_point && split_point < n-1 );

                for( i = 0; i < n1; i++ )
                {
                    int idx = sorted[i].i;
                    if( !dir[idx] )
                    {
                        int d = i <= split_point ? -1 : 1;
                        dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
                        if( --nz == 0 )
                            break;
                    }
                }
            }
        }
    }

    // find the default direction for the rest
    if( nz )
    {
        for( i = nr = 0; i < n; i++ )
            nr += dir[i] > 0;
        nl = n - nr - nz;
        d0 = nl > nr ? -1 : nr > nl;
    }

    // make sure that every sample is directed either to the left or to the right
    for( i = 0; i < n; i++ )
    {
        int d = dir[i];
        if( !d )
        {
            d = d0;
            if( !d )
                d = d1, d1 = -d1;
        }
        d = d > 0;
        dir[i] = (char)d; // remap (-1,1) to (0,1)
    }
}
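
/*
    Direction encoding used here and in split_node_data() below: on entry
    dir[i] is -1 (left), +1 (right) or 0 (unknown, e.g. the split variable is
    missing for the sample); on exit from complete_node_dir() every entry has
    been remapped to 0 (left) / 1 (right), so the splitting loops can use
    dir[i] directly as an arithmetic selector (ldst += d^1; rdst += d;).
*/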

void CvDTree::split_node_data( CvDTreeNode* node )
{
    int vi, i, n = node->sample_count, nl, nr;
    char* dir = (char*)data->direction->data.ptr;
    CvDTreeNode *left = 0, *right = 0;
    int* new_idx = data->split_buf->data.i;
    int new_buf_idx = data->get_child_buf_idx( node );
    int work_var_count = data->get_work_var_count();

    // speed things up a little, especially for tree ensembles with lots of small trees:
    // do not physically split the input data between the left and right child nodes
    // when we are not going to split them further,
    // as calc_node_value() does not require input features anyway.
    bool split_input_data;

    complete_node_dir(node);

    for( i = nl = nr = 0; i < n; i++ )
    {
        int d = dir[i];
        // initialize new indices for splitting ordered variables
        new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? nr : nl
        nr += d;
        nl += d^1;
    }

    node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
    node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
        (data->ord_var_count + work_var_count)*nl );

    split_input_data = node->depth + 1 < data->params.max_depth &&
        (node->left->sample_count > data->params.min_sample_count ||
        node->right->sample_count > data->params.min_sample_count);

    // split ordered variables, keep both halves sorted.
    for( vi = 0; vi < data->var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi);
        CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
        CvPair32s32f tl, tr;

        if( ci >= 0 || !split_input_data )
            continue;

        src = data->get_ord_var_data(node, vi);
        ldst0 = ldst = data->get_ord_var_data(left, vi);
        rdst0 = rdst = data->get_ord_var_data(right, vi);
        tl = ldst0[nl]; tr = rdst0[nr];

        // split sorted
        for( i = 0; i < n1; i++ )
        {
            int idx = src[i].i;
            float val = src[i].val;
            int d = dir[idx];
            idx = new_idx[idx];
            ldst->i = rdst->i = idx;
            ldst->val = rdst->val = val;
            ldst += d^1;
            rdst += d;
        }

        left->set_num_valid(vi, (int)(ldst - ldst0));
        right->set_num_valid(vi, (int)(rdst - rdst0));

        // split missing
        for( ; i < n; i++ )
        {
            int idx = src[i].i;
            int d = dir[idx];
            idx = new_idx[idx];
            ldst->i = rdst->i = idx;
            ldst->val = rdst->val = ord_nan;
            ldst += d^1;
            rdst += d;
        }

        ldst0[nl] = tl; rdst0[nr] = tr;
    }

    // split categorical vars, responses and cv_labels using the new_idx relocation table
    for( vi = 0; vi < work_var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi), nr1 = 0;
        int *src, *ldst0, *rdst0, *ldst, *rdst;
        int tl, tr;

        if( ci < 0 || (vi < data->var_count && !split_input_data) )
            continue;

        src = data->get_cat_var_data(node, vi);
        ldst0 = ldst = data->get_cat_var_data(left, vi);
        rdst0 = rdst = data->get_cat_var_data(right, vi);
        tl = ldst0[nl]; tr = rdst0[nr];

        for( i = 0; i < n; i++ )
        {
            int d = dir[i];
            int val = src[i];
            *ldst = *rdst = val;
            ldst += d^1;
            rdst += d;
            nr1 += (val >= 0)&d;
        }

        if( vi < data->var_count )
        {
            left->set_num_valid(vi, n1 - nr1);
            right->set_num_valid(vi, nr1);
        }

        ldst0[nl] = tl; rdst0[nr] = tr;
    }

    // deallocate the parent node data that is not needed anymore
    data->free_node_data(node);
}
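
/*
    The index relocation above uses a branchless select. Since d is 0 or 1
    (see complete_node_dir()), with two's-complement arithmetic:

        d == 0:  d-1 == ~0 (all ones), -d ==  0  =>  (nl & ~0) | (nr & 0) == nl
        d == 1:  d-1 ==  0,            -d == ~0  =>  (nl &  0) | (nr & ~0) == nr

    i.e. new_idx[i] = (nl & (d-1)) | (nr & -d) is exactly "d ? nr : nl"
    without a conditional jump.
*/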

void CvDTree::prune_cv()
{
    CvMat* ab = 0;
    CvMat* temp = 0;
    CvMat* err_jk = 0;

    // 1. build the tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if needed, apply the 1SE rule).
    // 3. store the best index and cut the branches.

    CV_FUNCNAME( "CvDTree::prune_cv" );

    __BEGIN__;

    int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
    double* err;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = update_tree_rnc(tree_count, -1);
        if( cut_tree(tree_count, -1, min_alpha) )
            break;

        if( ab->cols <= tree_count )
        {
            CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
            for( ti = 0; ti < ab->cols; ti++ )
                temp->data.db[ti] = ab->data.db[ti];
            cvReleaseMat( &ab );
            ab = temp;
            temp = 0;
        }

        ab->data.db[tree_count] = min_alpha;
    }

    ab->data.db[0] = 0.;

    if( tree_count > 0 )
    {
        for( ti = 1; ti < tree_count-1; ti++ )
            ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
        ab->data.db[tree_count-1] = DBL_MAX*0.5;

        CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
        err = err_jk->data.db;

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tk < tree_count; tj++ )
            {
                double min_alpha = update_tree_rnc(tj, j);
                if( cut_tree(tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab->data.db[tk] > min_alpha )
                        break;
                    err[j*tree_count + tk] = root->tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err[j*tree_count + ti];
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    pruned_tree_idx = min_idx;
    free_prune_data(data->params.truncate_pruned_tree != 0);

    __END__;

    cvReleaseMat( &err_jk );
    cvReleaseMat( &ab );
    cvReleaseMat( &temp );
}


double CvDTree::update_tree_rnc( int T, int fold )
{
    CvDTreeNode* node = root;
    double min_alpha = DBL_MAX;

    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
            if( t <= T || !node->left )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = node->cv_node_risk[fold];
                    node->tree_error = node->cv_node_error[fold];
                }
                break;
            }
            node = node->left;
        }

        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
        {
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
                - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = MIN( min_alpha, parent->alpha );
        }

        if( !parent )
            break;

        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        node = parent->right;
    }

    return min_alpha;
}
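
/*
    The pruning step follows the usual cost-complexity scheme: for every
    internal node, update_tree_rnc() maintains

        alpha = (R(node) - R(subtree)) / (#leaves(subtree) - 1),

    the per-leaf price of collapsing the subtree into its root ("complexity"
    counts the leaves, "tree_risk" is R(subtree)). cut_tree() then collapses
    the nodes whose alpha is minimal, producing the next tree in the nested
    sequence T0 > T1 > ..., and prune_cv() replaces each stored threshold
    with sqrt(alpha_i*alpha_{i+1}), the geometric mean of consecutive
    values, before scoring the sequence on the cross-validation folds.
*/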

int CvDTree::cut_tree( int T, int fold, double min_alpha )
{
    CvDTreeNode* node = root;
    if( !node->left )
        return 1;

    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
            if( t <= T || !node->left )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                if( fold >= 0 )
                    node->cv_Tn[fold] = T;
                else
                    node->Tn = T;
                if( node == root )
                    return 1;
                break;
            }
            node = node->left;
        }

        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
            ;

        if( !parent )
            break;

        node = parent->right;
    }

    return 0;
}


void CvDTree::free_prune_data(bool cut_tree)
{
    CvDTreeNode* node = root;

    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
            // as we will clear the whole cross-validation heap at the end
            node->cv_Tn = 0;
            node->cv_node_error = node->cv_node_risk = 0;
            if( !node->left )
                break;
            node = node->left;
        }

        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
        {
            if( cut_tree && parent->Tn <= pruned_tree_idx )
            {
                data->free_node( parent->left );
                data->free_node( parent->right );
                parent->left = parent->right = 0;
            }
        }

        if( !parent )
            break;

        node = parent->right;
    }

    if( data->cv_heap )
        cvClearSet( data->cv_heap );
}


void CvDTree::free_tree()
{
    if( root && data && data->shared )
    {
        pruned_tree_idx = INT_MIN;
        free_prune_data(true);
        data->free_node(root);
        root = 0;
    }
}
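
/*
    cut_tree(), free_prune_data() and update_tree_rnc() above, as well as the
    write/importance routines below, all share the same non-recursive
    depth-first walk over the tree; stripped of the per-node work it looks
    like this:

        CvDTreeNode* node = root;
        for(;;)
        {
            CvDTreeNode* parent;
            for( ; node->left; node = node->left )
                ;                              // descend to the leftmost leaf
            for( parent = node->parent; parent && parent->right == node;
                 node = parent, parent = parent->parent )
                ;                              // climb while returning from a right child
            if( !parent )
                break;                         // the whole tree has been visited
            node = parent->right;              // continue with the right sibling
        }
*/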

CvDTreeNode* CvDTree::predict( const CvMat* _sample,
    const CvMat* _missing, bool preprocessed_input ) const
{
    CvDTreeNode* result = 0;
    int* catbuf = 0;

    CV_FUNCNAME( "CvDTree::predict" );

    __BEGIN__;

    int i, step, mstep = 0;
    const float* sample;
    const uchar* m = 0;
    CvDTreeNode* node = root;
    const int* vtype;
    const int* vidx;
    const int* cmap;
    const int* cofs;

    if( !node )
        CV_ERROR( CV_StsError, "The tree has not been trained yet" );

    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
        (_sample->cols != 1 && _sample->rows != 1) ||
        (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
        (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
        CV_ERROR( CV_StsBadArg,
        "the input sample must be a 1d floating-point vector with the same "
        "number of elements as the total number of variables used for training" );

    sample = _sample->data.fl;
    step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);

    if( data->cat_count && !preprocessed_input ) // cache for categorical variables
    {
        int n = data->cat_count->cols;
        catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
        for( i = 0; i < n; i++ )
            catbuf[i] = -1;
    }

    if( _missing )
    {
        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
            !CV_ARE_SIZES_EQ(_missing, _sample) )
            CV_ERROR( CV_StsBadArg,
            "the missing data mask must be an 8-bit vector of the same size as the input sample" );
        m = _missing->data.ptr;
        mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
    }

    vtype = data->var_type->data.i;
    vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
    cmap = data->cat_map ? data->cat_map->data.i : 0;
    cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;

    while( node->Tn > pruned_tree_idx && node->left )
    {
        CvDTreeSplit* split = node->split;
        int dir = 0;
        for( ; !dir && split != 0; split = split->next )
        {
            int vi = split->var_idx;
            int ci = vtype[vi];
            i = vidx ? vidx[vi] : vi;
            float val = sample[i*step];
            if( m && m[i*mstep] )
                continue;
            if( ci < 0 ) // ordered
                dir = val <= split->ord.c ? -1 : 1;
            else // categorical
            {
                int c;
                if( preprocessed_input )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        int a = c = cofs[ci];
                        int b = cofs[ci+1];
                        int ival = cvRound(val);
                        if( ival != val )
                            CV_ERROR( CV_StsBadArg,
                            "one of the input categorical variables is not an integer" );

                        // binary search of the value in cat_map
                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        if( c < 0 || ival != cmap[c] )
                            continue;

                        catbuf[ci] = c -= cofs[ci];
                    }
                }
                dir = CV_DTREE_CAT_DIR(c, split->subset);
            }

            if( split->inversed )
                dir = -dir;
        }

        if( !dir )
        {
            // no split (primary or surrogate) could be applied:
            // send the sample to the larger branch
            double diff = node->right->sample_count - node->left->sample_count;
            dir = diff < 0 ? -1 : 1;
        }
        node = dir < 0 ? node->left : node->right;
    }

    result = node;

    __END__;

    return result;
}
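
/*
    Minimal usage sketch (hypothetical data; the sample must carry var_all
    elements in training order when preprocessed_input == false):

        CvDTree tree;
        // ... tree.train( ... ) on a problem with 3 input variables ...
        float vals[] = { 1.f, 0.f, 3.5f };
        CvMat sample = cvMat( 1, 3, CV_32FC1, vals );
        CvDTreeNode* leaf = tree.predict( &sample, 0 ); // no missing-data mask
        double response = leaf->value;                  // class label or regression value
*/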

const CvMat* CvDTree::get_var_importance()
{
    if( !var_importance )
    {
        CvDTreeNode* node = root;
        double* importance;
        if( !node )
            return 0;
        var_importance = cvCreateMat( 1, data->var_count, CV_64F );
        cvZero( var_importance );
        importance = var_importance->data.db;

        for(;;)
        {
            CvDTreeNode* parent;
            for( ;; node = node->left )
            {
                CvDTreeSplit* split = node->split;

                if( !node->left || node->Tn <= pruned_tree_idx )
                    break;

                for( ; split != 0; split = split->next )
                    importance[split->var_idx] += split->quality;
            }

            for( parent = node->parent; parent && parent->right == node;
                node = parent, parent = parent->parent )
                ;

            if( !parent )
                break;

            node = parent->right;
        }

        cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
    }

    return var_importance;
}


void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
{
    int ci;

    cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
    cvWriteInt( fs, "var", split->var_idx );
    cvWriteReal( fs, "quality", split->quality );

    ci = data->get_var_type(split->var_idx);
    if( ci >= 0 ) // split on a categorical var
    {
        int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;

        // ad-hoc rule for deciding when to use the inverse categorical split notation
        // to achieve a more compact and clear representation
        default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;

        cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
                            "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );

        for( i = 0; i < n; i++ )
        {
            int dir = CV_DTREE_CAT_DIR(i,split->subset);
            if( dir*default_dir < 0 )
                cvWriteInt( fs, 0, i );
        }
        cvEndWriteStruct( fs );
    }
    else
        cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );

    cvEndWriteStruct( fs );
}
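
/*
    Example of the inverse notation rule above: with n = 10 categories of
    which only categories 2 and 7 go to the right, to_right == 2 <= MIN(3,5),
    so default_dir == -1 and the split is stored compactly as
    "not_in: [ 2, 7 ]" rather than listing the eight left-going categories
    under "in".
*/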
"le" : "gt", split->ord.c ); 3083 3084 cvEndWriteStruct( fs ); 3085} 3086 3087 3088void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) 3089{ 3090 CvDTreeSplit* split; 3091 3092 cvStartWriteStruct( fs, 0, CV_NODE_MAP ); 3093 3094 cvWriteInt( fs, "depth", node->depth ); 3095 cvWriteInt( fs, "sample_count", node->sample_count ); 3096 cvWriteReal( fs, "value", node->value ); 3097 3098 if( data->is_classifier ) 3099 cvWriteInt( fs, "norm_class_idx", node->class_idx ); 3100 3101 cvWriteInt( fs, "Tn", node->Tn ); 3102 cvWriteInt( fs, "complexity", node->complexity ); 3103 cvWriteReal( fs, "alpha", node->alpha ); 3104 cvWriteReal( fs, "node_risk", node->node_risk ); 3105 cvWriteReal( fs, "tree_risk", node->tree_risk ); 3106 cvWriteReal( fs, "tree_error", node->tree_error ); 3107 3108 if( node->left ) 3109 { 3110 cvStartWriteStruct( fs, "splits", CV_NODE_SEQ ); 3111 3112 for( split = node->split; split != 0; split = split->next ) 3113 write_split( fs, split ); 3114 3115 cvEndWriteStruct( fs ); 3116 } 3117 3118 cvEndWriteStruct( fs ); 3119} 3120 3121 3122void CvDTree::write_tree_nodes( CvFileStorage* fs ) 3123{ 3124 //CV_FUNCNAME( "CvDTree::write_tree_nodes" ); 3125 3126 __BEGIN__; 3127 3128 CvDTreeNode* node = root; 3129 3130 // traverse the tree and save all the nodes in depth-first order 3131 for(;;) 3132 { 3133 CvDTreeNode* parent; 3134 for(;;) 3135 { 3136 write_node( fs, node ); 3137 if( !node->left ) 3138 break; 3139 node = node->left; 3140 } 3141 3142 for( parent = node->parent; parent && parent->right == node; 3143 node = parent, parent = parent->parent ) 3144 ; 3145 3146 if( !parent ) 3147 break; 3148 3149 node = parent->right; 3150 } 3151 3152 __END__; 3153} 3154 3155 3156void CvDTree::write( CvFileStorage* fs, const char* name ) 3157{ 3158 //CV_FUNCNAME( "CvDTree::write" ); 3159 3160 __BEGIN__; 3161 3162 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE ); 3163 3164 get_var_importance(); 3165 data->write_params( fs ); 3166 if( var_importance ) 3167 cvWrite( fs, "var_importance", var_importance ); 3168 write( fs ); 3169 3170 cvEndWriteStruct( fs ); 3171 3172 __END__; 3173} 3174 3175 3176void CvDTree::write( CvFileStorage* fs ) 3177{ 3178 //CV_FUNCNAME( "CvDTree::write" ); 3179 3180 __BEGIN__; 3181 3182 cvWriteInt( fs, "best_tree_idx", pruned_tree_idx ); 3183 3184 cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ ); 3185 write_tree_nodes( fs ); 3186 cvEndWriteStruct( fs ); 3187 3188 __END__; 3189} 3190 3191 3192CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode ) 3193{ 3194 CvDTreeSplit* split = 0; 3195 3196 CV_FUNCNAME( "CvDTree::read_split" ); 3197 3198 __BEGIN__; 3199 3200 int vi, ci; 3201 3202 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP ) 3203 CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" ); 3204 3205 vi = cvReadIntByName( fs, fnode, "var", -1 ); 3206 if( (unsigned)vi >= (unsigned)data->var_count ) 3207 CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" ); 3208 3209 ci = data->get_var_type(vi); 3210 if( ci >= 0 ) // split on categorical var 3211 { 3212 int i, n = data->cat_count->data.i[ci], inversed = 0, val; 3213 CvSeqReader reader; 3214 CvFileNode* inseq; 3215 split = data->new_split_cat( vi, 0 ); 3216 inseq = cvGetFileNodeByName( fs, fnode, "in" ); 3217 if( !inseq ) 3218 { 3219 inseq = cvGetFileNodeByName( fs, fnode, "not_in" ); 3220 inversed = 1; 3221 } 3222 if( !inseq || 3223 (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT)) 3224 CV_ERROR( 

CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
{
    CvDTreeSplit* split = 0;

    CV_FUNCNAME( "CvDTree::read_split" );

    __BEGIN__;

    int vi, ci;

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );

    vi = cvReadIntByName( fs, fnode, "var", -1 );
    if( (unsigned)vi >= (unsigned)data->var_count )
        CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );

    ci = data->get_var_type(vi);
    if( ci >= 0 ) // split on categorical var
    {
        int i, n = data->cat_count->data.i[ci], inversed = 0, val;
        CvSeqReader reader;
        CvFileNode* inseq;
        split = data->new_split_cat( vi, 0 );
        inseq = cvGetFileNodeByName( fs, fnode, "in" );
        if( !inseq )
        {
            inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
            inversed = 1;
        }
        if( !inseq ||
            (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
            CV_ERROR( CV_StsParseError,
            "Either an 'in' or a 'not_in' tag should be inside the categorical split data" );

        if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
        {
            val = inseq->data.i;
            if( (unsigned)val >= (unsigned)n )
                CV_ERROR( CV_StsOutOfRange, "some of the in/not_in elements are out of range" );

            split->subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            cvStartReadSeq( inseq->data.seq, &reader );

            for( i = 0; i < reader.seq->total; i++ )
            {
                CvFileNode* inode = (CvFileNode*)reader.ptr;
                val = inode->data.i;
                if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
                    CV_ERROR( CV_StsOutOfRange, "some of the in/not_in elements are out of range" );

                split->subset[val >> 5] |= 1 << (val & 31);
                CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
            }
        }

        // for categorical splits we do not use inversed splits,
        // instead we invert the variable set in the split
        if( inversed )
            for( i = 0; i < (n + 31) >> 5; i++ )
                split->subset[i] ^= -1;
    }
    else
    {
        CvFileNode* cmp_node;
        split = data->new_split_ord( vi, 0, 0, 0, 0 );

        cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
        if( !cmp_node )
        {
            cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
            split->inversed = 1;
        }

        split->ord.c = (float)cvReadReal( cmp_node );
    }

    split->quality = (float)cvReadRealByName( fs, fnode, "quality" );

    __END__;

    return split;
}


CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
{
    CvDTreeNode* node = 0;

    CV_FUNCNAME( "CvDTree::read_node" );

    __BEGIN__;

    CvFileNode* splits;
    int i, depth;

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );

    CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
    depth = cvReadIntByName( fs, fnode, "depth", -1 );
    if( depth != node->depth )
        CV_ERROR( CV_StsParseError, "incorrect node depth" );

    node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
    node->value = cvReadRealByName( fs, fnode, "value" );
    if( data->is_classifier )
        node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );

    node->Tn = cvReadIntByName( fs, fnode, "Tn" );
    node->complexity = cvReadIntByName( fs, fnode, "complexity" );
    node->alpha = cvReadRealByName( fs, fnode, "alpha" );
    node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
    node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
    node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );

    splits = cvGetFileNodeByName( fs, fnode, "splits" );
    if( splits )
    {
        CvSeqReader reader;
        CvDTreeSplit* last_split = 0;

        if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
            CV_ERROR( CV_StsParseError, "the splits tag must be stored as a sequence" );

        cvStartReadSeq( splits->data.seq, &reader );
        for( i = 0; i < reader.seq->total; i++ )
        {
            CvDTreeSplit* split;
            CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
            if( !last_split )
                node->split = last_split = split;
            else
                last_split = last_split->next = split;

            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }

    __END__;

    return node;
}

void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
{
    CV_FUNCNAME( "CvDTree::read_tree_nodes" );

    __BEGIN__;

    CvSeqReader reader;
    CvDTreeNode _root;
    CvDTreeNode* parent = &_root;
    int i;
    parent->left = parent->right = parent->parent = 0;

    cvStartReadSeq( fnode->data.seq, &reader );

    for( i = 0; i < reader.seq->total; i++ )
    {
        CvDTreeNode* node;

        CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
        if( !parent->left )
            parent->left = node;
        else
            parent->right = node;
        if( node->split )
            parent = node;
        else
        {
            while( parent && parent->right )
                parent = parent->parent;
        }

        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
    }

    root = _root.left;

    __END__;
}


void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
{
    CvDTreeTrainData* _data = new CvDTreeTrainData();
    _data->read_params( fs, fnode );

    read( fs, fnode, _data );
    get_var_importance();
}


// a special entry point for reading weak decision trees from tree ensembles
void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
{
    CV_FUNCNAME( "CvDTree::read" );

    __BEGIN__;

    CvFileNode* tree_nodes;

    clear();
    data = _data;

    tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
    if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
        CV_ERROR( CV_StsParseError, "nodes tag is missing" );

    pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
    read_tree_nodes( fs, tree_nodes );

    __END__;
}

/* End of file. */