1/*M///////////////////////////////////////////////////////////////////////////////////////
2//
3//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4//
5//  By downloading, copying, installing or using the software you agree to this license.
6//  If you do not agree to this license, do not download, install,
7//  copy or use the software.
8//
9//
10//                        Intel License Agreement
11//
12// Copyright (C) 2000, Intel Corporation, all rights reserved.
13// Third party copyrights are property of their respective owners.
14//
15// Redistribution and use in source and binary forms, with or without modification,
16// are permitted provided that the following conditions are met:
17//
18//   * Redistribution's of source code must retain the above copyright notice,
19//     this list of conditions and the following disclaimer.
20//
21//   * Redistribution's in binary form must reproduce the above copyright notice,
22//     this list of conditions and the following disclaimer in the documentation
23//     and/or other materials provided with the distribution.
24//
25//   * The name of Intel Corporation may not be used to endorse or promote products
26//     derived from this software without specific prior written permission.
27//
28// This software is provided by the copyright holders and contributors "as is" and
29// any express or implied warranties, including, but not limited to, the implied
30// warranties of merchantability and fitness for a particular purpose are disclaimed.
31// In no event shall the Intel Corporation or contributors be liable for any direct,
32// indirect, incidental, special, exemplary, or consequential damages
33// (including, but not limited to, procurement of substitute goods or services;
34// loss of use, data, or profits; or business interruption) however caused
35// and on any theory of liability, whether in contract, strict liability,
36// or tort (including negligence or otherwise) arising in any way out of
37// the use of this software, even if advised of the possibility of such damage.
38//
39//M*/
40
41#include "_ml.h"
42
43CvForestTree::CvForestTree()
44{
45    forest = NULL;
46}
47
48
49CvForestTree::~CvForestTree()
50{
51    clear();
52}
53
54
55bool CvForestTree::train( CvDTreeTrainData* _data,
56                          const CvMat* _subsample_idx,
57                          CvRTrees* _forest )
58{
59    bool result = false;
60
61    CV_FUNCNAME( "CvForestTree::train" );
62
63    __BEGIN__;
64
65
66    clear();
67    forest = _forest;
68
69    data = _data;
70    data->shared = true;
71    CV_CALL(result = do_train(_subsample_idx));
72
73    __END__;
74
75    return result;
76}
77
78
79bool
80CvForestTree::train( const CvMat*, int, const CvMat*, const CvMat*,
81                    const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
82{
83    assert(0);
84    return false;
85}
86
87
88bool
89CvForestTree::train( CvDTreeTrainData*, const CvMat* )
90{
91    assert(0);
92    return false;
93}
94
95
96CvDTreeSplit* CvForestTree::find_best_split( CvDTreeNode* node )
97{
98    int vi;
99    CvDTreeSplit *best_split = 0, *split = 0, *t;
100
101    CV_FUNCNAME("CvForestTree::find_best_split");
102    __BEGIN__;
103
104    CvMat* active_var_mask = 0;
105    if( forest )
106    {
107        int var_count;
108        CvRNG* rng = forest->get_rng();
109
110        active_var_mask = forest->get_active_var_mask();
111        var_count = active_var_mask->cols;
112
113        CV_ASSERT( var_count == data->var_count );
114
115        for( vi = 0; vi < var_count; vi++ )
116        {
117            uchar temp;
118            int i1 = cvRandInt(rng) % var_count;
119            int i2 = cvRandInt(rng) % var_count;
120            CV_SWAP( active_var_mask->data.ptr[i1],
121                active_var_mask->data.ptr[i2], temp );
122        }
123    }
124    for( vi = 0; vi < data->var_count; vi++ )
125    {
126        int ci = data->var_type->data.i[vi];
127        if( node->num_valid[vi] <= 1
128            || (active_var_mask && !active_var_mask->data.ptr[vi]) )
129            continue;
130
131        if( data->is_classifier )
132        {
133            if( ci >= 0 )
134                split = find_split_cat_class( node, vi );
135            else
136                split = find_split_ord_class( node, vi );
137        }
138        else
139        {
140            if( ci >= 0 )
141                split = find_split_cat_reg( node, vi );
142            else
143                split = find_split_ord_reg( node, vi );
144        }
145
146        if( split )
147        {
148            if( !best_split || best_split->quality < split->quality )
149                CV_SWAP( best_split, split, t );
150            if( split )
151                cvSetRemoveByPtr( data->split_heap, split );
152        }
153    }
154
155    __END__;
156
157    return best_split;
158}
159
160
161void CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData* _data )
162{
163    CvDTree::read( fs, fnode, _data );
164    forest = _forest;
165}
166
167
168void CvForestTree::read( CvFileStorage*, CvFileNode* )
169{
170    assert(0);
171}
172
173void CvForestTree::read( CvFileStorage* _fs, CvFileNode* _node,
174                         CvDTreeTrainData* _data )
175{
176    CvDTree::read( _fs, _node, _data );
177}
178
179
180//////////////////////////////////////////////////////////////////////////////////////////
181//                                  Random trees                                        //
182//////////////////////////////////////////////////////////////////////////////////////////
183
184CvRTrees::CvRTrees()
185{
186    nclasses         = 0;
187    oob_error        = 0;
188    ntrees           = 0;
189    trees            = NULL;
190    data             = NULL;
191    active_var_mask  = NULL;
192    var_importance   = NULL;
193    rng = cvRNG(0xffffffff);
194    default_model_name = "my_random_trees";
195}
196
197
198void CvRTrees::clear()
199{
200    int k;
201    for( k = 0; k < ntrees; k++ )
202        delete trees[k];
203    cvFree( &trees );
204
205    delete data;
206    data = 0;
207
208    cvReleaseMat( &active_var_mask );
209    cvReleaseMat( &var_importance );
210    ntrees = 0;
211}
212
213
214CvRTrees::~CvRTrees()
215{
216    clear();
217}
218
219
220CvMat* CvRTrees::get_active_var_mask()
221{
222    return active_var_mask;
223}
224
225
226CvRNG* CvRTrees::get_rng()
227{
228    return &rng;
229}
230
231bool CvRTrees::train( const CvMat* _train_data, int _tflag,
232                        const CvMat* _responses, const CvMat* _var_idx,
233                        const CvMat* _sample_idx, const CvMat* _var_type,
234                        const CvMat* _missing_mask, CvRTParams params )
235{
236    bool result = false;
237
238    CV_FUNCNAME("CvRTrees::train");
239    __BEGIN__;
240
241    int var_count = 0;
242
243    clear();
244
245    CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
246        params.regression_accuracy, params.use_surrogates, params.max_categories,
247        params.cv_folds, params.use_1se_rule, false, params.priors );
248
249    data = new CvDTreeTrainData();
250    CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
251        _sample_idx, _var_type, _missing_mask, tree_params, true));
252
253    var_count = data->var_count;
254    if( params.nactive_vars > var_count )
255        params.nactive_vars = var_count;
256    else if( params.nactive_vars == 0 )
257        params.nactive_vars = (int)sqrt((double)var_count);
258    else if( params.nactive_vars < 0 )
259        CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
260    params.term_crit = cvCheckTermCriteria( params.term_crit, 0.1, 1000 );
261
262    // Create mask of active variables at the tree nodes
263    CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
264    if( params.calc_var_importance )
265    {
266        CV_CALL(var_importance  = cvCreateMat( 1, var_count, CV_32FC1 ));
267        cvZero(var_importance);
268    }
269    { // initialize active variables mask
270        CvMat submask1, submask2;
271        cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
272        cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
273        cvSet( &submask1, cvScalar(1) );
274        cvZero( &submask2 );
275    }
276
277    CV_CALL(result = grow_forest( params.term_crit ));
278
279    result = true;
280
281    __END__;
282
283    return result;
284}
285
286
287bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
288{
289    bool result = false;
290
291    CvMat* sample_idx_mask_for_tree = 0;
292    CvMat* sample_idx_for_tree      = 0;
293
294    CvMat* oob_sample_votes	   = 0;
295    CvMat* oob_responses       = 0;
296
297    float* oob_samples_perm_ptr= 0;
298
299    float* samples_ptr     = 0;
300    uchar* missing_ptr     = 0;
301    float* true_resp_ptr   = 0;
302
303    CV_FUNCNAME("CvRTrees::grow_forest");
304    __BEGIN__;
305
306    const int max_ntrees = term_crit.max_iter;
307    const double max_oob_err = term_crit.epsilon;
308
309    const int dims = data->var_count;
310    float maximal_response = 0;
311
312    // oob_predictions_sum[i] = sum of predicted values for the i-th sample
313    // oob_num_of_predictions[i] = number of summands
314    //                            (number of predictions for the i-th sample)
315    // initialize these variable to avoid warning C4701
316    CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
317    CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
318
319    nsamples = data->sample_count;
320    nclasses = data->get_num_classes();
321
322    trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
323    memset( trees, 0, sizeof(trees[0])*max_ntrees );
324
325    if( data->is_classifier )
326    {
327        CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
328        cvZero(oob_sample_votes);
329    }
330    else
331    {
332        // oob_responses[0,i] = oob_predictions_sum[i]
333        //    = sum of predicted values for the i-th sample
334        // oob_responses[1,i] = oob_num_of_predictions[i]
335        //    = number of summands (number of predictions for the i-th sample)
336        CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
337        cvZero(oob_responses);
338        cvGetRow( oob_responses, &oob_predictions_sum, 0 );
339        cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
340    }
341    CV_CALL(sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 ));
342    CV_CALL(sample_idx_for_tree      = cvCreateMat( 1, nsamples, CV_32SC1 ));
343    CV_CALL(oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
344    CV_CALL(samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
345    CV_CALL(missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
346    CV_CALL(true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples ));
347
348    CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
349    {
350        double minval, maxval;
351        CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
352        cvMinMaxLoc( &responses, &minval, &maxval );
353        maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
354    }
355
356    ntrees = 0;
357    while( ntrees < max_ntrees )
358    {
359        int i, oob_samples_count = 0;
360        double ncorrect_responses = 0; // used for estimation of variable importance
361        CvMat sample, missing;
362        CvForestTree* tree = 0;
363
364        cvZero( sample_idx_mask_for_tree );
365        for( i = 0; i < nsamples; i++ ) //form sample for creation one tree
366        {
367            int idx = cvRandInt( &rng ) % nsamples;
368            sample_idx_for_tree->data.i[i] = idx;
369            sample_idx_mask_for_tree->data.ptr[idx] = 0xFF;
370        }
371
372        trees[ntrees] = new CvForestTree();
373        tree = trees[ntrees];
374        CV_CALL(tree->train( data, sample_idx_for_tree, this ));
375
376        // form array of OOB samples indices and get these samples
377        sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
378        missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
379
380        oob_error = 0;
381        for( i = 0; i < nsamples; i++,
382            sample.data.fl += dims, missing.data.ptr += dims )
383        {
384            CvDTreeNode* predicted_node = 0;
385            // check if the sample is OOB
386            if( sample_idx_mask_for_tree->data.ptr[i] )
387                continue;
388
389            // predict oob samples
390            if( !predicted_node )
391                CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
392
393            if( !data->is_classifier ) //regression
394            {
395                double avg_resp, resp = predicted_node->value;
396                oob_predictions_sum.data.fl[i] += (float)resp;
397                oob_num_of_predictions.data.fl[i] += 1;
398
399                // compute oob error
400                avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
401                avg_resp -= true_resp_ptr[i];
402                oob_error += avg_resp*avg_resp;
403                resp = (resp - true_resp_ptr[i])/maximal_response;
404                ncorrect_responses += exp( -resp*resp );
405            }
406            else //classification
407            {
408                double prdct_resp;
409                CvPoint max_loc;
410                CvMat votes;
411
412                cvGetRow(oob_sample_votes, &votes, i);
413                votes.data.i[predicted_node->class_idx]++;
414
415                // compute oob error
416                cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
417
418                prdct_resp = data->cat_map->data.i[max_loc.x];
419                oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
420
421                ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
422            }
423            oob_samples_count++;
424        }
425        if( oob_samples_count > 0 )
426            oob_error /= (double)oob_samples_count;
427
428        // estimate variable importance
429        if( var_importance && oob_samples_count > 0 )
430        {
431            int m;
432
433            memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
434            for( m = 0; m < dims; m++ )
435            {
436                double ncorrect_responses_permuted = 0;
437                // randomly permute values of the m-th variable in the oob samples
438                float* mth_var_ptr = oob_samples_perm_ptr + m;
439
440                for( i = 0; i < nsamples; i++ )
441                {
442                    int i1, i2;
443                    float temp;
444
445                    if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
446                        continue;
447                    i1 = cvRandInt( &rng ) % nsamples;
448                    i2 = cvRandInt( &rng ) % nsamples;
449                    CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
450
451                    // turn values of (m-1)-th variable, that were permuted
452                    // at the previous iteration, untouched
453                    if( m > 1 )
454                        oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
455                }
456
457                // predict "permuted" cases and calculate the number of votes for the
458                // correct class in the variable-m-permuted oob data
459                sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
460                missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
461                for( i = 0; i < nsamples; i++,
462                    sample.data.fl += dims, missing.data.ptr += dims )
463                {
464                    double predct_resp, true_resp;
465
466                    if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
467                        continue;
468
469                    predct_resp = tree->predict(&sample, &missing, true)->value;
470                    true_resp   = true_resp_ptr[i];
471                    if( data->is_classifier )
472                        ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
473                    else
474                    {
475                        true_resp = (true_resp - predct_resp)/maximal_response;
476                        ncorrect_responses_permuted += exp( -true_resp*true_resp );
477                    }
478                }
479                var_importance->data.fl[m] += (float)(ncorrect_responses
480                    - ncorrect_responses_permuted);
481            }
482        }
483        ntrees++;
484        if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
485            break;
486    }
487    if( var_importance )
488        CV_CALL(cvConvertScale( var_importance, var_importance, 1./ntrees/nsamples ));
489
490    result = true;
491
492    __END__;
493
494    cvReleaseMat( &sample_idx_mask_for_tree );
495    cvReleaseMat( &sample_idx_for_tree );
496    cvReleaseMat( &oob_sample_votes );
497    cvReleaseMat( &oob_responses );
498
499    cvFree( &oob_samples_perm_ptr );
500    cvFree( &samples_ptr );
501    cvFree( &missing_ptr );
502    cvFree( &true_resp_ptr );
503
504    return result;
505}
506
507
508const CvMat* CvRTrees::get_var_importance()
509{
510    return var_importance;
511}
512
513
514float CvRTrees::get_proximity( const CvMat* sample1, const CvMat* sample2,
515                              const CvMat* missing1, const CvMat* missing2 ) const
516{
517    float result = 0;
518
519    CV_FUNCNAME( "CvRTrees::get_proximity" );
520
521    __BEGIN__;
522
523    int i;
524    for( i = 0; i < ntrees; i++ )
525        result += trees[i]->predict( sample1, missing1 ) ==
526        trees[i]->predict( sample2, missing2 ) ?  1 : 0;
527    result = result/(float)ntrees;
528
529    __END__;
530
531    return result;
532}
533
534
535float CvRTrees::predict( const CvMat* sample, const CvMat* missing ) const
536{
537    double result = -1;
538
539    CV_FUNCNAME("CvRTrees::predict");
540    __BEGIN__;
541
542    int k;
543
544    if( nclasses > 0 ) //classification
545    {
546        int max_nvotes = 0;
547        int* votes = (int*)alloca( sizeof(int)*nclasses );
548        memset( votes, 0, sizeof(*votes)*nclasses );
549        for( k = 0; k < ntrees; k++ )
550        {
551            CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
552            int nvotes;
553            int class_idx = predicted_node->class_idx;
554            CV_ASSERT( 0 <= class_idx && class_idx < nclasses );
555
556            nvotes = ++votes[class_idx];
557            if( nvotes > max_nvotes )
558            {
559                max_nvotes = nvotes;
560                result = predicted_node->value;
561            }
562        }
563    }
564    else // regression
565    {
566        result = 0;
567        for( k = 0; k < ntrees; k++ )
568            result += trees[k]->predict( sample, missing )->value;
569        result /= (double)ntrees;
570    }
571
572    __END__;
573
574    return (float)result;
575}
576
577
578void CvRTrees::write( CvFileStorage* fs, const char* name )
579{
580    CV_FUNCNAME( "CvRTrees::write" );
581
582    __BEGIN__;
583
584    int k;
585
586    if( ntrees < 1 || !trees || nsamples < 1 )
587        CV_ERROR( CV_StsBadArg, "Invalid CvRTrees object" );
588
589    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES );
590
591    cvWriteInt( fs, "nclasses", nclasses );
592    cvWriteInt( fs, "nsamples", nsamples );
593    cvWriteInt( fs, "nactive_vars", (int)cvSum(active_var_mask).val[0] );
594    cvWriteReal( fs, "oob_error", oob_error );
595
596    if( var_importance )
597        cvWrite( fs, "var_importance", var_importance );
598
599    cvWriteInt( fs, "ntrees", ntrees );
600
601    CV_CALL(data->write_params( fs ));
602
603    cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
604
605    for( k = 0; k < ntrees; k++ )
606    {
607        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
608        CV_CALL( trees[k]->write( fs ));
609        cvEndWriteStruct( fs );
610    }
611
612    cvEndWriteStruct( fs ); //trees
613    cvEndWriteStruct( fs ); //CV_TYPE_NAME_ML_RTREES
614
615    __END__;
616}
617
618
619void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
620{
621    CV_FUNCNAME( "CvRTrees::read" );
622
623    __BEGIN__;
624
625    int nactive_vars, var_count, k;
626    CvSeqReader reader;
627    CvFileNode* trees_fnode = 0;
628
629    clear();
630
631    nclasses     = cvReadIntByName( fs, fnode, "nclasses", -1 );
632    nsamples     = cvReadIntByName( fs, fnode, "nsamples" );
633    nactive_vars = cvReadIntByName( fs, fnode, "nactive_vars", -1 );
634    oob_error    = cvReadRealByName(fs, fnode, "oob_error", -1 );
635    ntrees       = cvReadIntByName( fs, fnode, "ntrees", -1 );
636
637    var_importance = (CvMat*)cvReadByName( fs, fnode, "var_importance" );
638
639    if( nclasses < 0 || nsamples <= 0 || nactive_vars < 0 || oob_error < 0 || ntrees <= 0)
640        CV_ERROR( CV_StsParseError, "Some <nclasses>, <nsamples>, <var_count>, "
641        "<nactive_vars>, <oob_error>, <ntrees> of tags are missing" );
642
643    rng = CvRNG( -1 );
644
645    trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*ntrees );
646    memset( trees, 0, sizeof(trees[0])*ntrees );
647
648    data = new CvDTreeTrainData();
649    data->read_params( fs, fnode );
650    data->shared = true;
651
652    trees_fnode = cvGetFileNodeByName( fs, fnode, "trees" );
653    if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
654        CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
655
656    cvStartReadSeq( trees_fnode->data.seq, &reader );
657    if( reader.seq->total != ntrees )
658        CV_ERROR( CV_StsParseError,
659        "<ntrees> is not equal to the number of trees saved in file" );
660
661    for( k = 0; k < ntrees; k++ )
662    {
663        trees[k] = new CvForestTree();
664        CV_CALL(trees[k]->read( fs, (CvFileNode*)reader.ptr, this, data ));
665        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
666    }
667
668    var_count = data->var_count;
669    CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
670    {
671        // initialize active variables mask
672        CvMat submask1, submask2;
673        cvGetCols( active_var_mask, &submask1, 0, nactive_vars );
674        cvGetCols( active_var_mask, &submask2, nactive_vars, var_count );
675        cvSet( &submask1, cvScalar(1) );
676        cvZero( &submask2 );
677    }
678
679    __END__;
680}
681
682
683int CvRTrees::get_tree_count() const
684{
685    return ntrees;
686}
687
688CvForestTree* CvRTrees::get_tree(int i) const
689{
690    return (unsigned)i < (unsigned)ntrees ? trees[i] : 0;
691}
692
693// End of file.
694