/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "_ml.h"

CvForestTree::CvForestTree()
{
    forest = NULL;
}


CvForestTree::~CvForestTree()
{
    clear();
}


bool CvForestTree::train( CvDTreeTrainData* _data,
                          const CvMat* _subsample_idx,
                          CvRTrees* _forest )
{
    bool result = false;

    CV_FUNCNAME( "CvForestTree::train" );

    __BEGIN__;

    clear();
    forest = _forest;

    data = _data;
    data->shared = true;
    CV_CALL(result = do_train(_subsample_idx));

    __END__;

    return result;
}


bool
CvForestTree::train( const CvMat*, int, const CvMat*, const CvMat*,
                     const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
{
    assert(0);
    return false;
}


bool
CvForestTree::train( CvDTreeTrainData*, const CvMat* )
{
    assert(0);
    return false;
}


CvDTreeSplit* CvForestTree::find_best_split( CvDTreeNode* node )
{
    int vi;
    CvDTreeSplit *best_split = 0, *split = 0, *t;

    CV_FUNCNAME("CvForestTree::find_best_split");
    __BEGIN__;

    CvMat* active_var_mask = 0;
    if( forest )
    {
        int var_count;
        CvRNG* rng = forest->get_rng();

        active_var_mask = forest->get_active_var_mask();
        var_count = active_var_mask->cols;

        CV_ASSERT( var_count == data->var_count );

        // randomly shuffle the active variable mask, so that this node
        // considers its own random subset of <nactive_vars> variables
        for( vi = 0; vi < var_count; vi++ )
        {
            uchar temp;
            int i1 = cvRandInt(rng) % var_count;
            int i2 = cvRandInt(rng) % var_count;
            CV_SWAP( active_var_mask->data.ptr[i1],
                     active_var_mask->data.ptr[i2], temp );
        }
    }
    for( vi = 0; vi < data->var_count; vi++ )
    {
        int ci = data->var_type->data.i[vi];
        if( node->num_valid[vi] <= 1
            || (active_var_mask && !active_var_mask->data.ptr[vi]) )
            continue;

        if( data->is_classifier )
        {
            if( ci >= 0 )
                split = find_split_cat_class( node, vi );
            else
                split = find_split_ord_class( node, vi );
        }
        else
        {
            if( ci >= 0 )
                split = find_split_cat_reg( node, vi );
            else
                split = find_split_ord_reg( node, vi );
        }

        if( split )
        {
            if( !best_split || best_split->quality < split->quality )
                CV_SWAP( best_split, split, t );
            if( split )
                cvSetRemoveByPtr( data->split_heap, split );
        }
    }

    __END__;

    return best_split;
}


void CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData* _data )
{
    CvDTree::read( fs, fnode, _data );
    forest = _forest;
}


void CvForestTree::read( CvFileStorage*, CvFileNode* )
{
    assert(0);
}

void CvForestTree::read( CvFileStorage* _fs, CvFileNode* _node,
                         CvDTreeTrainData* _data )
{
    CvDTree::read( _fs, _node, _data );
}


//////////////////////////////////////////////////////////////////////////////////////////
//                                  Random trees                                         //
//////////////////////////////////////////////////////////////////////////////////////////

CvRTrees::CvRTrees()
{
    nclasses = 0;
    oob_error = 0;
    ntrees = 0;
    trees = NULL;
    data = NULL;
    active_var_mask = NULL;
    var_importance = NULL;
    rng = cvRNG(0xffffffff);
    default_model_name = "my_random_trees";
}


void CvRTrees::clear()
{
    int k;
    for( k = 0; k < ntrees; k++ )
        delete trees[k];
    cvFree( &trees );

    delete data;
    data = 0;

    cvReleaseMat( &active_var_mask );
    cvReleaseMat( &var_importance );
    ntrees = 0;
}


CvRTrees::~CvRTrees()
{
    clear();
}
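
/*
 * Note on the active variable mask (an explanatory sketch added for clarity, not part of
 * the original comments): CvRTrees::train() below creates a 1 x var_count CV_8UC1 mask
 * whose first nactive_vars entries are 1 and the rest 0. CvForestTree::find_best_split()
 * then randomly shuffles this mask at every node, so each split is chosen only among a
 * random subset of nactive_vars variables -- the usual "mtry" idea of random forests.
 * When the caller leaves nactive_vars at 0 it defaults to sqrt(var_count), e.g. for 100
 * variables roughly 10 candidate variables are examined per split.
 */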

CvMat*
CvRTrees::get_active_var_mask()
{
    return active_var_mask;
}


CvRNG* CvRTrees::get_rng()
{
    return &rng;
}

bool CvRTrees::train( const CvMat* _train_data, int _tflag,
                      const CvMat* _responses, const CvMat* _var_idx,
                      const CvMat* _sample_idx, const CvMat* _var_type,
                      const CvMat* _missing_mask, CvRTParams params )
{
    bool result = false;

    CV_FUNCNAME("CvRTrees::train");
    __BEGIN__;

    int var_count = 0;

    clear();

    CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
        params.regression_accuracy, params.use_surrogates, params.max_categories,
        params.cv_folds, params.use_1se_rule, false, params.priors );

    data = new CvDTreeTrainData();
    CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
        _sample_idx, _var_type, _missing_mask, tree_params, true));

    var_count = data->var_count;
    if( params.nactive_vars > var_count )
        params.nactive_vars = var_count;
    else if( params.nactive_vars == 0 )
        params.nactive_vars = (int)sqrt((double)var_count);
    else if( params.nactive_vars < 0 )
        CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
    params.term_crit = cvCheckTermCriteria( params.term_crit, 0.1, 1000 );

    // Create mask of active variables at the tree nodes
    CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
    if( params.calc_var_importance )
    {
        CV_CALL(var_importance = cvCreateMat( 1, var_count, CV_32FC1 ));
        cvZero(var_importance);
    }
    { // initialize active variables mask
        CvMat submask1, submask2;
        cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
        cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
        cvSet( &submask1, cvScalar(1) );
        cvZero( &submask2 );
    }

    CV_CALL(result = grow_forest( params.term_crit ));

    result = true;

    __END__;

    return result;
}
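
/*
 * Summary of the training loop below (an explanatory note added for clarity, not
 * original documentation). grow_forest() repeatedly, up to term_crit.max_iter times:
 *   1. draws a bootstrap sample of the training set (nsamples indices sampled with
 *      replacement) and trains one CvForestTree on it;
 *   2. runs the new tree on its out-of-bag (OOB) samples and updates the running OOB
 *      error estimate (misclassification rate of the cumulative majority vote for
 *      classification, mean squared error of the averaged predictions for regression);
 *   3. if var_importance was requested, updates the permutation-based variable
 *      importance (see the note inside the loop).
 * The loop stops early when term_crit includes CV_TERMCRIT_EPS and the OOB error
 * drops below term_crit.epsilon.
 */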

bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
{
    bool result = false;

    CvMat* sample_idx_mask_for_tree = 0;
    CvMat* sample_idx_for_tree = 0;

    CvMat* oob_sample_votes = 0;
    CvMat* oob_responses = 0;

    float* oob_samples_perm_ptr = 0;

    float* samples_ptr = 0;
    uchar* missing_ptr = 0;
    float* true_resp_ptr = 0;

    CV_FUNCNAME("CvRTrees::grow_forest");
    __BEGIN__;

    const int max_ntrees = term_crit.max_iter;
    const double max_oob_err = term_crit.epsilon;

    const int dims = data->var_count;
    float maximal_response = 0;

    // oob_predictions_sum[i] = sum of predicted values for the i-th sample
    // oob_num_of_predictions[i] = number of summands
    //                             (number of predictions for the i-th sample)
    // initialize these variables to avoid warning C4701
    CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
    CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );

    nsamples = data->sample_count;
    nclasses = data->get_num_classes();

    trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
    memset( trees, 0, sizeof(trees[0])*max_ntrees );

    if( data->is_classifier )
    {
        CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
        cvZero(oob_sample_votes);
    }
    else
    {
        // oob_responses[0,i] = oob_predictions_sum[i]
        //                    = sum of predicted values for the i-th sample
        // oob_responses[1,i] = oob_num_of_predictions[i]
        //                    = number of summands (number of predictions for the i-th sample)
        CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
        cvZero(oob_responses);
        cvGetRow( oob_responses, &oob_predictions_sum, 0 );
        cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
    }
    CV_CALL(sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 ));
    CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));
    CV_CALL(oob_samples_perm_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
    CV_CALL(samples_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
    CV_CALL(missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
    CV_CALL(true_resp_ptr = (float*)cvAlloc( sizeof(float)*nsamples ));

    CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
    {
        double minval, maxval;
        CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
        cvMinMaxLoc( &responses, &minval, &maxval );
        maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
    }

    ntrees = 0;
    while( ntrees < max_ntrees )
    {
        int i, oob_samples_count = 0;
        double ncorrect_responses = 0; // used for estimation of variable importance
        CvMat sample, missing;
        CvForestTree* tree = 0;

        cvZero( sample_idx_mask_for_tree );
        for( i = 0; i < nsamples; i++ ) // form the bootstrap sample used to train one tree
        {
            int idx = cvRandInt( &rng ) % nsamples;
            sample_idx_for_tree->data.i[i] = idx;
            sample_idx_mask_for_tree->data.ptr[idx] = 0xFF;
        }

        trees[ntrees] = new CvForestTree();
        tree = trees[ntrees];
        CV_CALL(tree->train( data, sample_idx_for_tree, this ));

        // form the array of OOB sample indices and get these samples
        sample = cvMat( 1, dims, CV_32FC1, samples_ptr );
        missing = cvMat( 1, dims, CV_8UC1, missing_ptr );

        oob_error = 0;
        for( i = 0; i < nsamples; i++,
             sample.data.fl += dims, missing.data.ptr += dims )
        {
            CvDTreeNode* predicted_node = 0;
            // check if the sample is OOB
            if( sample_idx_mask_for_tree->data.ptr[i] )
                continue;

            // predict oob samples
            if( !predicted_node )
                CV_CALL(predicted_node = tree->predict(&sample, &missing, true));

            if( !data->is_classifier ) // regression
            {
                double avg_resp, resp = predicted_node->value;
                oob_predictions_sum.data.fl[i] += (float)resp;
                oob_num_of_predictions.data.fl[i] += 1;

                // compute oob error
                avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
                avg_resp -= true_resp_ptr[i];
                oob_error += avg_resp*avg_resp;
                resp = (resp - true_resp_ptr[i])/maximal_response;
                ncorrect_responses += exp( -resp*resp );
            }
            else // classification
            {
                double prdct_resp;
                CvPoint max_loc;
                CvMat votes;

                cvGetRow(oob_sample_votes, &votes, i);
                votes.data.i[predicted_node->class_idx]++;

                // compute oob error
                cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );

                prdct_resp = data->cat_map->data.i[max_loc.x];
                oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;

                ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
            }
            oob_samples_count++;
        }
        if( oob_samples_count > 0 )
            oob_error /= (double)oob_samples_count;
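
        /*
         * Note added for clarity (not part of the original comments): the block below
         * estimates variable importance by permutation. For each variable m, its values
         * are randomly shuffled in a scratch copy of the data, the tree's OOB samples
         * are re-predicted, and the resulting drop in the number of correct responses
         * (ncorrect_responses - ncorrect_responses_permuted) is accumulated into
         * var_importance[m]. Once the forest is grown, the accumulated sums are scaled
         * by 1/(ntrees*nsamples) before being returned by get_var_importance().
         */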
        // estimate variable importance
        if( var_importance && oob_samples_count > 0 )
        {
            int m;

            memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
            for( m = 0; m < dims; m++ )
            {
                double ncorrect_responses_permuted = 0;
                // randomly permute values of the m-th variable in the oob samples
                float* mth_var_ptr = oob_samples_perm_ptr + m;

                for( i = 0; i < nsamples; i++ )
                {
                    int i1, i2;
                    float temp;

                    if( sample_idx_mask_for_tree->data.ptr[i] ) // the sample is not OOB
                        continue;
                    i1 = cvRandInt( &rng ) % nsamples;
                    i2 = cvRandInt( &rng ) % nsamples;
                    CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );

                    // restore the values of the (m-1)-th variable, which were permuted
                    // at the previous iteration (this must also run for m == 1,
                    // otherwise the 0-th variable would stay permuted)
                    if( m > 0 )
                        oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
                }

                // predict "permuted" cases and calculate the number of votes for the
                // correct class in the variable-m-permuted oob data
                sample = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
                missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
                for( i = 0; i < nsamples; i++,
                     sample.data.fl += dims, missing.data.ptr += dims )
                {
                    double predct_resp, true_resp;

                    if( sample_idx_mask_for_tree->data.ptr[i] ) // the sample is not OOB
                        continue;

                    predct_resp = tree->predict(&sample, &missing, true)->value;
                    true_resp = true_resp_ptr[i];
                    if( data->is_classifier )
                        ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
                    else
                    {
                        true_resp = (true_resp - predct_resp)/maximal_response;
                        ncorrect_responses_permuted += exp( -true_resp*true_resp );
                    }
                }
                var_importance->data.fl[m] += (float)(ncorrect_responses
                    - ncorrect_responses_permuted);
            }
        }
        ntrees++;
        if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
            break;
    }
    if( var_importance )
        CV_CALL(cvConvertScale( var_importance, var_importance, 1./ntrees/nsamples ));

    result = true;

    __END__;

    cvReleaseMat( &sample_idx_mask_for_tree );
    cvReleaseMat( &sample_idx_for_tree );
    cvReleaseMat( &oob_sample_votes );
    cvReleaseMat( &oob_responses );

    cvFree( &oob_samples_perm_ptr );
    cvFree( &samples_ptr );
    cvFree( &missing_ptr );
    cvFree( &true_resp_ptr );

    return result;
}


const CvMat* CvRTrees::get_var_importance()
{
    return var_importance;
}
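
/*
 * Interpretation note (added; the snippet is an illustrative sketch, not part of the
 * library): get_var_importance() returns the raw, 1/(ntrees*nsamples)-scaled importance
 * sums accumulated in grow_forest(). A common way to present them is relative to their
 * total, e.g.:
 *
 *     const CvMat* imp = forest.get_var_importance();
 *     if( imp )
 *     {
 *         double total = cvSum( imp ).val[0];
 *         for( int v = 0; v < imp->cols; v++ )
 *             printf( "var #%d: %.1f%%\n", v,
 *                     100.*imp->data.fl[v]/MAX(total, DBL_EPSILON) );
 *     }
 */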

float CvRTrees::get_proximity( const CvMat* sample1, const CvMat* sample2,
                               const CvMat* missing1, const CvMat* missing2 ) const
{
    float result = 0;

    CV_FUNCNAME( "CvRTrees::get_proximity" );

    __BEGIN__;

    int i;
    for( i = 0; i < ntrees; i++ )
        result += trees[i]->predict( sample1, missing1 ) ==
            trees[i]->predict( sample2, missing2 ) ? 1 : 0;
    result = result/(float)ntrees;

    __END__;

    return result;
}


float CvRTrees::predict( const CvMat* sample, const CvMat* missing ) const
{
    double result = -1;

    CV_FUNCNAME("CvRTrees::predict");
    __BEGIN__;

    int k;

    if( nclasses > 0 ) // classification
    {
        int max_nvotes = 0;
        int* votes = (int*)alloca( sizeof(int)*nclasses );
        memset( votes, 0, sizeof(*votes)*nclasses );
        for( k = 0; k < ntrees; k++ )
        {
            CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
            int nvotes;
            int class_idx = predicted_node->class_idx;
            CV_ASSERT( 0 <= class_idx && class_idx < nclasses );

            nvotes = ++votes[class_idx];
            if( nvotes > max_nvotes )
            {
                max_nvotes = nvotes;
                result = predicted_node->value;
            }
        }
    }
    else // regression
    {
        result = 0;
        for( k = 0; k < ntrees; k++ )
            result += trees[k]->predict( sample, missing )->value;
        result /= (double)ntrees;
    }

    __END__;

    return (float)result;
}


void CvRTrees::write( CvFileStorage* fs, const char* name )
{
    CV_FUNCNAME( "CvRTrees::write" );

    __BEGIN__;

    int k;

    if( ntrees < 1 || !trees || nsamples < 1 )
        CV_ERROR( CV_StsBadArg, "Invalid CvRTrees object" );

    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES );

    cvWriteInt( fs, "nclasses", nclasses );
    cvWriteInt( fs, "nsamples", nsamples );
    cvWriteInt( fs, "nactive_vars", (int)cvSum(active_var_mask).val[0] );
    cvWriteReal( fs, "oob_error", oob_error );

    if( var_importance )
        cvWrite( fs, "var_importance", var_importance );

    cvWriteInt( fs, "ntrees", ntrees );

    CV_CALL(data->write_params( fs ));

    cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );

    for( k = 0; k < ntrees; k++ )
    {
        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
        CV_CALL( trees[k]->write( fs ));
        cvEndWriteStruct( fs );
    }

    cvEndWriteStruct( fs ); //trees
    cvEndWriteStruct( fs ); //CV_TYPE_NAME_ML_RTREES

    __END__;
}


void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
{
    CV_FUNCNAME( "CvRTrees::read" );

    __BEGIN__;

    int nactive_vars, var_count, k;
    CvSeqReader reader;
    CvFileNode* trees_fnode = 0;

    clear();

    nclasses = cvReadIntByName( fs, fnode, "nclasses", -1 );
    nsamples = cvReadIntByName( fs, fnode, "nsamples" );
    nactive_vars = cvReadIntByName( fs, fnode, "nactive_vars", -1 );
    oob_error = cvReadRealByName(fs, fnode, "oob_error", -1 );
    ntrees = cvReadIntByName( fs, fnode, "ntrees", -1 );

    var_importance = (CvMat*)cvReadByName( fs, fnode, "var_importance" );

    if( nclasses < 0 || nsamples <= 0 || nactive_vars < 0 || oob_error < 0 || ntrees <= 0)
        CV_ERROR( CV_StsParseError, "Some of the <nclasses>, <nsamples>, "
            "<nactive_vars>, <oob_error>, <ntrees> tags are missing or invalid" );

    rng = CvRNG( -1 );

    trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*ntrees );
    memset( trees, 0, sizeof(trees[0])*ntrees );

    data = new CvDTreeTrainData();
    data->read_params( fs, fnode );
    data->shared = true;

    trees_fnode = cvGetFileNodeByName( fs, fnode, "trees" );
    if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
        CV_ERROR( CV_StsParseError, "<trees> tag is missing" );

    cvStartReadSeq( trees_fnode->data.seq, &reader );
    if( reader.seq->total != ntrees )
        CV_ERROR( CV_StsParseError,
            "<ntrees> is not equal to the number of trees saved in file" );

    for( k = 0; k < ntrees; k++ )
    {
        trees[k] = new CvForestTree();
        CV_CALL(trees[k]->read( fs, (CvFileNode*)reader.ptr, this, data ));
        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
    }

    var_count = data->var_count;
    CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
    {
        // initialize active variables mask
        CvMat submask1, submask2;
        cvGetCols( active_var_mask, &submask1, 0, nactive_vars );
        cvGetCols( active_var_mask, &submask2, nactive_vars, var_count );
        cvSet( &submask1, cvScalar(1) );
        cvZero( &submask2 );
    }

    __END__;
}


int CvRTrees::get_tree_count() const
{
    return ntrees;
}

CvForestTree* CvRTrees::get_tree(int i) const
{
    return (unsigned)i < (unsigned)ntrees ? trees[i] : 0;
}

// End of file.
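
#if 0
// Minimal usage sketch, appended for documentation only and excluded from the build.
// The matrices, sizes and parameter values are assumptions, not part of this
// implementation; it exercises the CvRTrees interface defined above: train(),
// predict(), get_var_importance() and get_tree_count().
static void example_rtrees_usage( const CvMat* train_data, const CvMat* responses,
                                  const CvMat* var_type, const CvMat* test_sample )
{
    CvRTParams params;                  // start from the default parameters in ml.h
    params.calc_var_importance = true;  // accumulate permutation-based importance
    params.nactive_vars = 0;            // 0 -> sqrt(var_count) candidate vars per split
    params.term_crit = cvTermCriteria( CV_TERMCRIT_ITER + CV_TERMCRIT_EPS,
                                       50,     // at most 50 trees
                                       0.05 ); // or stop once OOB error < 5%

    CvRTrees forest;
    forest.train( train_data, CV_ROW_SAMPLE, responses,
                  0, 0, var_type, 0, params );

    printf( "trees grown: %d\n", forest.get_tree_count() );
    printf( "prediction for the test sample: %f\n", forest.predict( test_sample, 0 ));

    if( forest.get_var_importance() )
        printf( "importance of var #0: %f\n", forest.get_var_importance()->data.fl[0] );
}
#endif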