1/*M///////////////////////////////////////////////////////////////////////////////////////
2//
3//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4//
5//  By downloading, copying, installing or using the software you agree to this license.
6//  If you do not agree to this license, do not download, install,
7//  copy or use the software.
8//
9//
10//                        Intel License Agreement
11//
12// Copyright (C) 2000, Intel Corporation, all rights reserved.
13// Third party copyrights are property of their respective owners.
14//
15// Redistribution and use in source and binary forms, with or without modification,
16// are permitted provided that the following conditions are met:
17//
18//   * Redistribution's of source code must retain the above copyright notice,
19//     this list of conditions and the following disclaimer.
20//
21//   * Redistribution's in binary form must reproduce the above copyright notice,
22//     this list of conditions and the following disclaimer in the documentation
23//     and/or other materials provided with the distribution.
24//
25//   * The name of Intel Corporation may not be used to endorse or promote products
26//     derived from this software without specific prior written permission.
27//
28// This software is provided by the copyright holders and contributors "as is" and
29// any express or implied warranties, including, but not limited to, the implied
30// warranties of merchantability and fitness for a particular purpose are disclaimed.
31// In no event shall the Intel Corporation or contributors be liable for any direct,
32// indirect, incidental, special, exemplary, or consequential damages
33// (including, but not limited to, procurement of substitute goods or services;
34// loss of use, data, or profits; or business interruption) however caused
35// and on any theory of liability, whether in contract, strict liability,
36// or tort (including negligence or otherwise) arising in any way out of
37// the use of this software, even if advised of the possibility of such damage.
38//
39//M*/
40
41#include "_ml.h"
42
// Value stored in place of a missing ordered measurement. Every real ordered
// value must satisfy |x| < ord_nan, so missing entries sort after valid ones.
static const float ord_nan = 0.5f*FLT_MAX;
// Lower bound (64 KiB) for the block size of the memory storages created in set_data().
static const int min_block_size = 65536;
// Slack (1 KiB) added to the largest element size when choosing a storage block size.
static const int block_size_delta = 1024;
46
47CvDTreeTrainData::CvDTreeTrainData()
48{
49    var_idx = var_type = cat_count = cat_ofs = cat_map =
50        priors = priors_mult = counts = buf = direction = split_buf = 0;
51    tree_storage = temp_storage = 0;
52
53    clear();
54}
55
56
57CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
58                      const CvMat* _responses, const CvMat* _var_idx,
59                      const CvMat* _sample_idx, const CvMat* _var_type,
60                      const CvMat* _missing_mask, const CvDTreeParams& _params,
61                      bool _shared, bool _add_labels )
62{
63    var_idx = var_type = cat_count = cat_ofs = cat_map =
64        priors = priors_mult = counts = buf = direction = split_buf = 0;
65    tree_storage = temp_storage = 0;
66
67    set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
68              _var_type, _missing_mask, _params, _shared, _add_labels );
69}
70
71
72CvDTreeTrainData::~CvDTreeTrainData()
73{
74    clear();
75}
76
77
78bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
79{
80    bool ok = false;
81
82    CV_FUNCNAME( "CvDTreeTrainData::set_params" );
83
84    __BEGIN__;
85
86    // set parameters
87    params = _params;
88
89    if( params.max_categories < 2 )
90        CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
91    params.max_categories = MIN( params.max_categories, 15 );
92
93    if( params.max_depth < 0 )
94        CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
95    params.max_depth = MIN( params.max_depth, 25 );
96
97    params.min_sample_count = MAX(params.min_sample_count,1);
98
99    if( params.cv_folds < 0 )
100        CV_ERROR( CV_StsOutOfRange,
101        "params.cv_folds should be =0 (the tree is not pruned) "
102        "or n>0 (tree is pruned using n-fold cross-validation)" );
103
104    if( params.cv_folds == 1 )
105        params.cv_folds = 0;
106
107    if( params.regression_accuracy < 0 )
108        CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
109
110    ok = true;
111
112    __END__;
113
114    return ok;
115}
116
117
// "Less than" for arrays of pointers: compares the values the pointers refer to.
#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
// Sort helpers generated by CV_IMPLEMENT_QSORT_EX. Judging by the call
// icvSortIntPtr( int_ptr, sample_count, 0 ) in set_data(), each takes
// (array, element count, aux argument of the last macro parameter type).
//
// icvSortIntPtr: sorts an array of int* by the pointed-to ints. set_data()
// uses it to rank categorical values while keeping access to their slots.
static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
// icvSortDblPtr: the same for double*; not used in the code visible here.
static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )

// "Less than" for (index, value) pairs: compares the .val members only.
#define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
// icvSortPairs: sorts CvPair32s32f elements by value. set_data() uses it for
// ordered variables.
static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
124
125void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
126    const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
127    const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
128    bool _shared, bool _add_labels, bool _update_data )
129{
130    CvMat* sample_idx = 0;
131    CvMat* var_type0 = 0;
132    CvMat* tmp_map = 0;
133    int** int_ptr = 0;
134    CvDTreeTrainData* data = 0;
135
136    CV_FUNCNAME( "CvDTreeTrainData::set_data" );
137
138    __BEGIN__;
139
140    int sample_all = 0, r_type = 0, cv_n;
141    int total_c_count = 0;
142    int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
143    int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
144    int vi, i;
145    char err[100];
146    const int *sidx = 0, *vidx = 0;
147
148    if( _update_data && data_root )
149    {
150        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
151            _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
152
153        // compare new and old train data
154        if( !(data->var_count == var_count &&
155            cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
156            cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
157            cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
158            CV_ERROR( CV_StsBadArg,
159            "The new training data must have the same types and the input and output variables "
160            "and the same categories for categorical variables" );
161
162        cvReleaseMat( &priors );
163        cvReleaseMat( &priors_mult );
164        cvReleaseMat( &buf );
165        cvReleaseMat( &direction );
166        cvReleaseMat( &split_buf );
167        cvReleaseMemStorage( &temp_storage );
168
169        priors = data->priors; data->priors = 0;
170        priors_mult = data->priors_mult; data->priors_mult = 0;
171        buf = data->buf; data->buf = 0;
172        buf_count = data->buf_count; buf_size = data->buf_size;
173        sample_count = data->sample_count;
174
175        direction = data->direction; data->direction = 0;
176        split_buf = data->split_buf; data->split_buf = 0;
177        temp_storage = data->temp_storage; data->temp_storage = 0;
178        nv_heap = data->nv_heap; cv_heap = data->cv_heap;
179
180        data_root = new_node( 0, sample_count, 0, 0 );
181        EXIT;
182    }
183
184    clear();
185
186    var_all = 0;
187    rng = cvRNG(-1);
188
189    CV_CALL( set_params( _params ));
190
191    // check parameter types and sizes
192    CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
193    if( _tflag == CV_ROW_SAMPLE )
194    {
195        ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
196        dv_step = 1;
197        if( _missing_mask )
198            ms_step = _missing_mask->step, mv_step = 1;
199    }
200    else
201    {
202        dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
203        ds_step = 1;
204        if( _missing_mask )
205            mv_step = _missing_mask->step, ms_step = 1;
206    }
207
208    sample_count = sample_all;
209    var_count = var_all;
210
211    if( _sample_idx )
212    {
213        CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
214        sidx = sample_idx->data.i;
215        sample_count = sample_idx->rows + sample_idx->cols - 1;
216    }
217
218    if( _var_idx )
219    {
220        CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
221        vidx = var_idx->data.i;
222        var_count = var_idx->rows + var_idx->cols - 1;
223    }
224
225    if( !CV_IS_MAT(_responses) ||
226        (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
227         CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
228        _responses->rows != 1 && _responses->cols != 1 ||
229        _responses->rows + _responses->cols - 1 != sample_all )
230        CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
231                  "floating-point vector containing as many elements as "
232                  "the total number of samples in the training data matrix" );
233
234    CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
235    CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
236
237    cat_var_count = 0;
238    ord_var_count = -1;
239
240    is_classifier = r_type == CV_VAR_CATEGORICAL;
241
242    // step 0. calc the number of categorical vars
243    for( vi = 0; vi < var_count; vi++ )
244    {
245        var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
246            cat_var_count++ : ord_var_count--;
247    }
248
249    ord_var_count = ~ord_var_count;
250    cv_n = params.cv_folds;
251    // set the two last elements of var_type array to be able
252    // to locate responses and cross-validation labels using
253    // the corresponding get_* functions.
254    var_type->data.i[var_count] = cat_var_count;
255    var_type->data.i[var_count+1] = cat_var_count+1;
256
257    // in case of single ordered predictor we need dummy cv_labels
258    // for safe split_node_data() operation
259    have_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0 || _add_labels;
260
261    buf_size = (ord_var_count + get_work_var_count())*sample_count + 2;
262    shared = _shared;
263    buf_count = shared ? 3 : 2;
264    CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
265    CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
266    CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
267    CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
268
269    // now calculate the maximum size of split,
270    // create memory storage that will keep nodes and splits of the decision tree
271    // allocate root node and the buffer for the whole training data
272    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
273        (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
274    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
275    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
276    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
277    CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
278
279    nv_size = var_count*sizeof(int);
280    nv_size = MAX( nv_size, (int)sizeof(CvSetElem) );
281
282    temp_block_size = nv_size;
283
284    if( cv_n )
285    {
286        if( sample_count < cv_n*MAX(params.min_sample_count,10) )
287            CV_ERROR( CV_StsOutOfRange,
288                "The many folds in cross-validation for such a small dataset" );
289
290        cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
291        temp_block_size = MAX(temp_block_size, cv_size);
292    }
293
294    temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
295    CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
296    CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
297    if( cv_size )
298        CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
299
300    CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
301    CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
302
303    max_c_count = 1;
304
305    // transform the training data to convenient representation
306    for( vi = 0; vi <= var_count; vi++ )
307    {
308        int ci;
309        const uchar* mask = 0;
310        int m_step = 0, step;
311        const int* idata = 0;
312        const float* fdata = 0;
313        int num_valid = 0;
314
315        if( vi < var_count ) // analyze i-th input variable
316        {
317            int vi0 = vidx ? vidx[vi] : vi;
318            ci = get_var_type(vi);
319            step = ds_step; m_step = ms_step;
320            if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
321                idata = _train_data->data.i + vi0*dv_step;
322            else
323                fdata = _train_data->data.fl + vi0*dv_step;
324            if( _missing_mask )
325                mask = _missing_mask->data.ptr + vi0*mv_step;
326        }
327        else // analyze _responses
328        {
329            ci = cat_var_count;
330            step = CV_IS_MAT_CONT(_responses->type) ?
331                1 : _responses->step / CV_ELEM_SIZE(_responses->type);
332            if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
333                idata = _responses->data.i;
334            else
335                fdata = _responses->data.fl;
336        }
337
338        if( vi < var_count && ci >= 0 ||
339            vi == var_count && is_classifier ) // process categorical variable or response
340        {
341            int c_count, prev_label;
342            int* c_map, *dst = get_cat_var_data( data_root, vi );
343
344            // copy data
345            for( i = 0; i < sample_count; i++ )
346            {
347                int val = INT_MAX, si = sidx ? sidx[i] : i;
348                if( !mask || !mask[si*m_step] )
349                {
350                    if( idata )
351                        val = idata[si*step];
352                    else
353                    {
354                        float t = fdata[si*step];
355                        val = cvRound(t);
356                        if( val != t )
357                        {
358                            sprintf( err, "%d-th value of %d-th (categorical) "
359                                "variable is not an integer", i, vi );
360                            CV_ERROR( CV_StsBadArg, err );
361                        }
362                    }
363
364                    if( val == INT_MAX )
365                    {
366                        sprintf( err, "%d-th value of %d-th (categorical) "
367                            "variable is too large", i, vi );
368                        CV_ERROR( CV_StsBadArg, err );
369                    }
370                    num_valid++;
371                }
372                dst[i] = val;
373                int_ptr[i] = dst + i;
374            }
375
376            // sort all the values, including the missing measurements
377            // that should all move to the end
378            icvSortIntPtr( int_ptr, sample_count, 0 );
379            //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
380
381            c_count = num_valid > 0;
382
383            // count the categories
384            for( i = 1; i < num_valid; i++ )
385                c_count += *int_ptr[i] != *int_ptr[i-1];
386
387            if( vi > 0 )
388                max_c_count = MAX( max_c_count, c_count );
389            cat_count->data.i[ci] = c_count;
390            cat_ofs->data.i[ci] = total_c_count;
391
392            // resize cat_map, if need
393            if( cat_map->cols < total_c_count + c_count )
394            {
395                tmp_map = cat_map;
396                CV_CALL( cat_map = cvCreateMat( 1,
397                    MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
398                for( i = 0; i < total_c_count; i++ )
399                    cat_map->data.i[i] = tmp_map->data.i[i];
400                cvReleaseMat( &tmp_map );
401            }
402
403            c_map = cat_map->data.i + total_c_count;
404            total_c_count += c_count;
405
406            // compact the class indices and build the map
407            prev_label = ~*int_ptr[0];
408            c_count = -1;
409
410            for( i = 0; i < num_valid; i++ )
411            {
412                int cur_label = *int_ptr[i];
413                if( cur_label != prev_label )
414                    c_map[++c_count] = prev_label = cur_label;
415                *int_ptr[i] = c_count;
416            }
417
418            // replace labels for missing values with -1
419            for( ; i < sample_count; i++ )
420                *int_ptr[i] = -1;
421        }
422        else if( ci < 0 ) // process ordered variable
423        {
424            CvPair32s32f* dst = get_ord_var_data( data_root, vi );
425
426            for( i = 0; i < sample_count; i++ )
427            {
428                float val = ord_nan;
429                int si = sidx ? sidx[i] : i;
430                if( !mask || !mask[si*m_step] )
431                {
432                    if( idata )
433                        val = (float)idata[si*step];
434                    else
435                        val = fdata[si*step];
436
437                    if( fabs(val) >= ord_nan )
438                    {
439                        sprintf( err, "%d-th value of %d-th (ordered) "
440                            "variable (=%g) is too large", i, vi, val );
441                        CV_ERROR( CV_StsBadArg, err );
442                    }
443                    num_valid++;
444                }
445                dst[i].i = i;
446                dst[i].val = val;
447            }
448
449            icvSortPairs( dst, sample_count, 0 );
450        }
451        else // special case: process ordered response,
452             // it will be stored similarly to categorical vars (i.e. no pairs)
453        {
454            float* dst = get_ord_responses( data_root );
455
456            for( i = 0; i < sample_count; i++ )
457            {
458                float val = ord_nan;
459                int si = sidx ? sidx[i] : i;
460                if( idata )
461                    val = (float)idata[si*step];
462                else
463                    val = fdata[si*step];
464
465                if( fabs(val) >= ord_nan )
466                {
467                    sprintf( err, "%d-th value of %d-th (ordered) "
468                        "variable (=%g) is out of range", i, vi, val );
469                    CV_ERROR( CV_StsBadArg, err );
470                }
471                dst[i] = val;
472            }
473
474            cat_count->data.i[cat_var_count] = 0;
475            cat_ofs->data.i[cat_var_count] = total_c_count;
476            num_valid = sample_count;
477        }
478
479        if( vi < var_count )
480            data_root->set_num_valid(vi, num_valid);
481    }
482
483    if( cv_n )
484    {
485        int* dst = get_labels(data_root);
486        CvRNG* r = &rng;
487
488        for( i = vi = 0; i < sample_count; i++ )
489        {
490            dst[i] = vi++;
491            vi &= vi < cv_n ? -1 : 0;
492        }
493
494        for( i = 0; i < sample_count; i++ )
495        {
496            int a = cvRandInt(r) % sample_count;
497            int b = cvRandInt(r) % sample_count;
498            CV_SWAP( dst[a], dst[b], vi );
499        }
500    }
501
502    cat_map->cols = MAX( total_c_count, 1 );
503
504    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
505        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
506    CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
507
508    have_priors = is_classifier && params.priors;
509    if( is_classifier )
510    {
511        int m = get_num_classes();
512        double sum = 0;
513        CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
514        for( i = 0; i < m; i++ )
515        {
516            double val = have_priors ? params.priors[i] : 1.;
517            if( val <= 0 )
518                CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
519            priors->data.db[i] = val;
520            sum += val;
521        }
522
523        // normalize weights
524        if( have_priors )
525            cvScale( priors, priors, 1./sum );
526
527        CV_CALL( priors_mult = cvCloneMat( priors ));
528        CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
529    }
530
531    CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
532    CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
533
534    __END__;
535
536    if( data )
537        delete data;
538
539    cvFree( &int_ptr );
540    cvReleaseMat( &sample_idx );
541    cvReleaseMat( &var_type0 );
542    cvReleaseMat( &tmp_map );
543}
544
545
// Creates a new root node that refers to all training samples or to a
// subsample of them.
//
// _subsample_idx - optional index array into the sample_count training
//                  samples (normalized by cvPreprocessIndexArray). It may
//                  contain repeated indices.
//
// Returns the new root node. The result is 0 if an error is raised before
// the node is allocated, e.g. when no training data has been set.
CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );

    __BEGIN__;

    if( !data_root )
        CV_ERROR( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));

    if( !isubsample_idx )
    {
        // make a copy of the root node
        // The copy shares data_root's buffer slice (buf_idx/offset). It keeps
        // the num_valid and cross-validation arrays that new_node() just
        // allocated for it rather than aliasing data_root's arrays, and only
        // the num_valid contents are copied over.
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
    {
        // Gather the selected samples into a new node stored in buffer row 1.
        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        //   co[k*2]   - number of occurrences of sample k in the subsample;
        //   co[k*2+1] - position of its first copy in the new node, or -1 if
        //               sample k is not selected.
        int* co, cur_ofs = 0;
        int vi, i, total = data_root->sample_count;
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;
        int work_var_count = get_work_var_count();
        root = new_node( 0, count, 1, 0 );

        CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
        cvZero( subsample_co );
        co = subsample_co->data.i;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < total; i++ )
        {
            if( co[i*2] )
            {
                co[i*2+1] = cur_ofs;
                cur_ofs += co[i*2];
            }
            else
                co[i*2+1] = -1;
        }

        // Work variables are the inputs followed by the response and, when
        // present, the labels (vi >= var_count). All of those are copied
        // through the per-sample int path below.
        for( vi = 0; vi < work_var_count; vi++ )
        {
            int ci = get_var_type(vi);

            if( ci >= 0 || vi >= var_count )
            {
                // plain per-sample array: gather by subsample index
                const int* src = get_cat_var_data( data_root, vi );
                int* dst = get_cat_var_data( root, vi );
                int num_valid = 0;

                for( i = 0; i < count; i++ )
                {
                    int val = src[sidx[i]];
                    dst[i] = val;
                    num_valid += val >= 0; // missing categories are stored as -1
                }

                if( vi < var_count )
                    root->set_num_valid(vi, num_valid);
            }
            else
            {
                // Ordered variable: src is sorted by value with its num_valid
                // non-missing entries first. Each selected sample is emitted
                // once per occurrence and renumbered to its positions in the
                // subsample, so dst stays sorted.
                const CvPair32s32f* src = get_ord_var_data( data_root, vi );
                CvPair32s32f* dst = get_ord_var_data( root, vi );
                int j = 0, idx, count_i;
                int num_valid = data_root->get_num_valid(vi);

                for( i = 0; i < num_valid; i++ )
                {
                    idx = src[i].i;
                    count_i = co[idx*2];
                    if( count_i )
                    {
                        float val = src[i].val;
                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        {
                            dst[j].val = val;
                            dst[j].i = cur_ofs;
                        }
                    }
                }

                root->set_num_valid(vi, j);

                // the missing entries follow the valid ones, in src order
                for( ; i < total; i++ )
                {
                    idx = src[i].i;
                    count_i = co[idx*2];
                    if( count_i )
                    {
                        float val = src[i].val;
                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        {
                            dst[j].val = val;
                            dst[j].i = cur_ofs;
                        }
                    }
                }
            }
        }
    }

    __END__;

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );

    return root;
}
675
676
677void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
678                                    float* values, uchar* missing,
679                                    float* responses, bool get_class_idx )
680{
681    CvMat* subsample_idx = 0;
682    CvMat* subsample_co = 0;
683
684    CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
685
686    __BEGIN__;
687
688    int i, vi, total = sample_count, count = total, cur_ofs = 0;
689    int* sidx = 0;
690    int* co = 0;
691
692    if( _subsample_idx )
693    {
694        CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
695        sidx = subsample_idx->data.i;
696        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
697        co = subsample_co->data.i;
698        cvZero( subsample_co );
699        count = subsample_idx->cols + subsample_idx->rows - 1;
700        for( i = 0; i < count; i++ )
701            co[sidx[i]*2]++;
702        for( i = 0; i < total; i++ )
703        {
704            int count_i = co[i*2];
705            if( count_i )
706            {
707                co[i*2+1] = cur_ofs*var_count;
708                cur_ofs += count_i;
709            }
710        }
711    }
712
713    if( missing )
714        memset( missing, 1, count*var_count );
715
716    for( vi = 0; vi < var_count; vi++ )
717    {
718        int ci = get_var_type(vi);
719        if( ci >= 0 ) // categorical
720        {
721            float* dst = values + vi;
722            uchar* m = missing ? missing + vi : 0;
723            const int* src = get_cat_var_data(data_root, vi);
724
725            for( i = 0; i < count; i++, dst += var_count )
726            {
727                int idx = sidx ? sidx[i] : i;
728                int val = src[idx];
729                *dst = (float)val;
730                if( m )
731                {
732                    *m = val < 0;
733                    m += var_count;
734                }
735            }
736        }
737        else // ordered
738        {
739            float* dst = values + vi;
740            uchar* m = missing ? missing + vi : 0;
741            const CvPair32s32f* src = get_ord_var_data(data_root, vi);
742            int count1 = data_root->get_num_valid(vi);
743
744            for( i = 0; i < count1; i++ )
745            {
746                int idx = src[i].i;
747                int count_i = 1;
748                if( co )
749                {
750                    count_i = co[idx*2];
751                    cur_ofs = co[idx*2+1];
752                }
753                else
754                    cur_ofs = idx*var_count;
755                if( count_i )
756                {
757                    float val = src[i].val;
758                    for( ; count_i > 0; count_i--, cur_ofs += var_count )
759                    {
760                        dst[cur_ofs] = val;
761                        if( m )
762                            m[cur_ofs] = 0;
763                    }
764                }
765            }
766        }
767    }
768
769    // copy responses
770    if( responses )
771    {
772        if( is_classifier )
773        {
774            const int* src = get_class_labels(data_root);
775            for( i = 0; i < count; i++ )
776            {
777                int idx = sidx ? sidx[i] : i;
778                int val = get_class_idx ? src[idx] :
779                    cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
780                responses[i] = (float)val;
781            }
782        }
783        else
784        {
785            const float* src = get_ord_responses(data_root);
786            for( i = 0; i < count; i++ )
787            {
788                int idx = sidx ? sidx[i] : i;
789                responses[i] = src[idx];
790            }
791        }
792    }
793
794    __END__;
795
796    cvReleaseMat( &subsample_idx );
797    cvReleaseMat( &subsample_co );
798}
799
800
801CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
802                                         int storage_idx, int offset )
803{
804    CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
805
806    node->sample_count = count;
807    node->depth = parent ? parent->depth + 1 : 0;
808    node->parent = parent;
809    node->left = node->right = 0;
810    node->split = 0;
811    node->value = 0;
812    node->class_idx = 0;
813    node->maxlr = 0.;
814
815    node->buf_idx = storage_idx;
816    node->offset = offset;
817    if( nv_heap )
818        node->num_valid = (int*)cvSetNew( nv_heap );
819    else
820        node->num_valid = 0;
821    node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
822    node->complexity = 0;
823
824    if( params.cv_folds > 0 && cv_heap )
825    {
826        int cv_n = params.cv_folds;
827        node->Tn = INT_MAX;
828        node->cv_Tn = (int*)cvSetNew( cv_heap );
829        node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
830        node->cv_node_error = node->cv_node_risk + cv_n;
831    }
832    else
833    {
834        node->Tn = 0;
835        node->cv_Tn = 0;
836        node->cv_node_risk = 0;
837        node->cv_node_error = 0;
838    }
839
840    return node;
841}
842
843
844CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
845                int split_point, int inversed, float quality )
846{
847    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
848    split->var_idx = vi;
849    split->ord.c = cmp_val;
850    split->ord.split_point = split_point;
851    split->inversed = inversed;
852    split->quality = quality;
853    split->next = 0;
854
855    return split;
856}
857
858
859CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
860{
861    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
862    int i, n = (max_c_count + 31)/32;
863
864    split->var_idx = vi;
865    split->inversed = 0;
866    split->quality = quality;
867    for( i = 0; i < n; i++ )
868        split->subset[i] = 0;
869    split->next = 0;
870
871    return split;
872}
873
874
875void CvDTreeTrainData::free_node( CvDTreeNode* node )
876{
877    CvDTreeSplit* split = node->split;
878    free_node_data( node );
879    while( split )
880    {
881        CvDTreeSplit* next = split->next;
882        cvSetRemoveByPtr( split_heap, split );
883        split = next;
884    }
885    node->split = 0;
886    cvSetRemoveByPtr( node_heap, node );
887}
888
889
890void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
891{
892    if( node->num_valid )
893    {
894        cvSetRemoveByPtr( nv_heap, node->num_valid );
895        node->num_valid = 0;
896    }
897    // do not free cv_* fields, as all the cross-validation related data is released at once.
898}
899
900
901void CvDTreeTrainData::free_train_data()
902{
903    cvReleaseMat( &counts );
904    cvReleaseMat( &buf );
905    cvReleaseMat( &direction );
906    cvReleaseMat( &split_buf );
907    cvReleaseMemStorage( &temp_storage );
908    cv_heap = nv_heap = 0;
909}
910
911
912void CvDTreeTrainData::clear()
913{
914    free_train_data();
915
916    cvReleaseMemStorage( &tree_storage );
917
918    cvReleaseMat( &var_idx );
919    cvReleaseMat( &var_type );
920    cvReleaseMat( &cat_count );
921    cvReleaseMat( &cat_ofs );
922    cvReleaseMat( &cat_map );
923    cvReleaseMat( &priors );
924    cvReleaseMat( &priors_mult );
925
926    node_heap = split_heap = 0;
927
928    sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
929    have_labels = have_priors = is_classifier = false;
930
931    buf_count = buf_size = 0;
932    shared = false;
933
934    data_root = 0;
935
936    rng = cvRNG(-1);
937}
938
939
940int CvDTreeTrainData::get_num_classes() const
941{
942    return is_classifier ? cat_count->data.i[cat_var_count] : 0;
943}
944
945
946int CvDTreeTrainData::get_var_type(int vi) const
947{
948    return var_type->data.i[vi];
949}
950
951
952int CvDTreeTrainData::get_work_var_count() const
953{
954    return var_count + 1 + (have_labels ? 1 : 0);
955}
956
957CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
958{
959    int oi = ~get_var_type(vi);
960    assert( 0 <= oi && oi < ord_var_count );
961    return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
962                           n->offset + oi*n->sample_count*2);
963}
964
965
966int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
967{
968    return get_cat_var_data( n, var_count );
969}
970
971
972float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
973{
974    return (float*)get_cat_var_data( n, var_count );
975}
976
977
978int* CvDTreeTrainData::get_labels( CvDTreeNode* n )
979{
980    return have_labels ? get_cat_var_data( n, var_count + 1 ) : 0;
981}
982
983
984int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
985{
986    int ci = get_var_type(vi);
987    assert( 0 <= ci && ci <= cat_var_count + 1 );
988    return buf->data.i + n->buf_idx*buf->cols + n->offset +
989           (ord_var_count*2 + ci)*n->sample_count;
990}
991
992
993int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
994{
995    int idx = n->buf_idx + 1;
996    if( idx >= buf_count )
997        idx = shared ? 1 : 0;
998    return idx;
999}
1000
1001
1002void CvDTreeTrainData::write_params( CvFileStorage* fs )
1003{
1004    CV_FUNCNAME( "CvDTreeTrainData::write_params" );
1005
1006    __BEGIN__;
1007
1008    int vi, vcount = var_count;
1009
1010    cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
1011    cvWriteInt( fs, "var_all", var_all );
1012    cvWriteInt( fs, "var_count", var_count );
1013    cvWriteInt( fs, "ord_var_count", ord_var_count );
1014    cvWriteInt( fs, "cat_var_count", cat_var_count );
1015
1016    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
1017    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
1018
1019    if( is_classifier )
1020    {
1021        cvWriteInt( fs, "max_categories", params.max_categories );
1022    }
1023    else
1024    {
1025        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
1026    }
1027
1028    cvWriteInt( fs, "max_depth", params.max_depth );
1029    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
1030    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
1031
1032    if( params.cv_folds > 1 )
1033    {
1034        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
1035        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
1036    }
1037
1038    if( priors )
1039        cvWrite( fs, "priors", priors );
1040
1041    cvEndWriteStruct( fs );
1042
1043    if( var_idx )
1044        cvWrite( fs, "var_idx", var_idx );
1045
1046    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
1047
1048    for( vi = 0; vi < vcount; vi++ )
1049        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
1050
1051    cvEndWriteStruct( fs );
1052
1053    if( cat_count && (cat_var_count > 0 || is_classifier) )
1054    {
1055        CV_ASSERT( cat_count != 0 );
1056        cvWrite( fs, "cat_count", cat_count );
1057        cvWrite( fs, "cat_map", cat_map );
1058    }
1059
1060    __END__;
1061}
1062
1063
1064void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
1065{
1066    CV_FUNCNAME( "CvDTreeTrainData::read_params" );
1067
1068    __BEGIN__;
1069
1070    CvFileNode *tparams_node, *vartype_node;
1071    CvSeqReader reader;
1072    int vi, max_split_size, tree_block_size;
1073
1074    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
1075    var_all = cvReadIntByName( fs, node, "var_all" );
1076    var_count = cvReadIntByName( fs, node, "var_count", var_all );
1077    cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
1078    ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
1079
1080    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
1081
1082    if( tparams_node ) // training parameters are not necessary
1083    {
1084        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
1085
1086        if( is_classifier )
1087        {
1088            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
1089        }
1090        else
1091        {
1092            params.regression_accuracy =
1093                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
1094        }
1095
1096        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
1097        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
1098        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
1099
1100        if( params.cv_folds > 1 )
1101        {
1102            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
1103            params.truncate_pruned_tree =
1104                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
1105        }
1106
1107        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
1108        if( priors )
1109        {
1110            if( !CV_IS_MAT(priors) )
1111                CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
1112            priors_mult = cvCloneMat( priors );
1113        }
1114    }
1115
1116    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
1117    if( var_idx )
1118    {
1119        if( !CV_IS_MAT(var_idx) ||
1120            var_idx->cols != 1 && var_idx->rows != 1 ||
1121            var_idx->cols + var_idx->rows - 1 != var_count ||
1122            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
1123            CV_ERROR( CV_StsParseError,
1124                "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
1125
1126        for( vi = 0; vi < var_count; vi++ )
1127            if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
1128                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
1129    }
1130
1131    ////// read var type
1132    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
1133
1134    cat_var_count = 0;
1135    ord_var_count = -1;
1136    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
1137
1138    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
1139        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
1140    else
1141    {
1142        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
1143            vartype_node->data.seq->total != var_count )
1144            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1145
1146        cvStartReadSeq( vartype_node->data.seq, &reader );
1147
1148        for( vi = 0; vi < var_count; vi++ )
1149        {
1150            CvFileNode* n = (CvFileNode*)reader.ptr;
1151            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
1152                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1153            var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
1154            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1155        }
1156    }
1157    var_type->data.i[var_count] = cat_var_count;
1158
1159    ord_var_count = ~ord_var_count;
1160    if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
1161        CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
1162    //////
1163
1164    if( cat_var_count > 0 || is_classifier )
1165    {
1166        int ccount, total_c_count = 0;
1167        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
1168        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
1169
1170        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
1171            cat_count->cols != 1 && cat_count->rows != 1 ||
1172            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
1173            cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
1174            cat_map->cols != 1 && cat_map->rows != 1 ||
1175            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
1176            CV_ERROR( CV_StsParseError,
1177            "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
1178
1179        ccount = cat_var_count + is_classifier;
1180
1181        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
1182        cat_ofs->data.i[0] = 0;
1183        max_c_count = 1;
1184
1185        for( vi = 0; vi < ccount; vi++ )
1186        {
1187            int val = cat_count->data.i[vi];
1188            if( val <= 0 )
1189                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
1190            max_c_count = MAX( max_c_count, val );
1191            cat_ofs->data.i[vi+1] = total_c_count += val;
1192        }
1193
1194        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
1195            CV_ERROR( CV_StsBadSize,
1196            "cat_map vector length is not equal to the total number of categories in all categorical vars" );
1197    }
1198
1199    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
1200        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
1201
1202    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
1203    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
1204    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
1205    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
1206            sizeof(CvDTreeNode), tree_storage ));
1207    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
1208            max_split_size, tree_storage ));
1209
1210    __END__;
1211}
1212
1213
1214/////////////////////// Decision Tree /////////////////////////
1215
1216CvDTree::CvDTree()
1217{
1218    data = 0;
1219    var_importance = 0;
1220    default_model_name = "my_tree";
1221
1222    clear();
1223}
1224
1225
1226void CvDTree::clear()
1227{
1228    cvReleaseMat( &var_importance );
1229    if( data )
1230    {
1231        if( !data->shared )
1232            delete data;
1233        else
1234            free_tree();
1235        data = 0;
1236    }
1237    root = 0;
1238    pruned_tree_idx = -1;
1239}
1240
1241
1242CvDTree::~CvDTree()
1243{
1244    clear();
1245}
1246
1247
1248const CvDTreeNode* CvDTree::get_root() const
1249{
1250    return root;
1251}
1252
1253
1254int CvDTree::get_pruned_tree_idx() const
1255{
1256    return pruned_tree_idx;
1257}
1258
1259
1260CvDTreeTrainData* CvDTree::get_data()
1261{
1262    return data;
1263}
1264
1265
1266bool CvDTree::train( const CvMat* _train_data, int _tflag,
1267                     const CvMat* _responses, const CvMat* _var_idx,
1268                     const CvMat* _sample_idx, const CvMat* _var_type,
1269                     const CvMat* _missing_mask, CvDTreeParams _params )
1270{
1271    bool result = false;
1272
1273    CV_FUNCNAME( "CvDTree::train" );
1274
1275    __BEGIN__;
1276
1277    clear();
1278    data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1279                                 _var_idx, _sample_idx, _var_type,
1280                                 _missing_mask, _params, false );
1281    CV_CALL( result = do_train(0));
1282
1283    __END__;
1284
1285    return result;
1286}
1287
1288
1289bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1290{
1291    bool result = false;
1292
1293    CV_FUNCNAME( "CvDTree::train" );
1294
1295    __BEGIN__;
1296
1297    clear();
1298    data = _data;
1299    data->shared = true;
1300    CV_CALL( result = do_train(_subsample_idx));
1301
1302    __END__;
1303
1304    return result;
1305}
1306
1307
1308bool CvDTree::do_train( const CvMat* _subsample_idx )
1309{
1310    bool result = false;
1311
1312    CV_FUNCNAME( "CvDTree::do_train" );
1313
1314    __BEGIN__;
1315
1316    root = data->subsample_data( _subsample_idx );
1317
1318    CV_CALL( try_split_node(root));
1319
1320    if( data->params.cv_folds > 0 )
1321        CV_CALL( prune_cv());
1322
1323    if( !data->shared )
1324        data->free_train_data();
1325
1326    result = true;
1327
1328    __END__;
1329
1330    return result;
1331}
1332
1333
// Recursively grow the subtree rooted at `node`:
//  1. compute the node value (calc_node_value),
//  2. decide whether the node may be split at all (sample count, depth,
//     class purity for classifiers, regression_accuracy for regressors),
//  3. find the best primary split and, if params.use_surrogates is set,
//     attach surrogate splits sorted by decreasing scaled quality,
//  4. distribute the samples to the children and recurse into both.
// A node that ends up unsplit keeps no per-node data (free_node_data).
void CvDTree::try_split_node( CvDTreeNode* node )
{
    CvDTreeSplit* best_split = 0;
    int i, n = node->sample_count, vi;
    bool can_split = true;
    double quality_scale;

    calc_node_value( node );

    // stopping criteria: too few samples or maximum depth reached
    if( node->sample_count <= data->params.min_sample_count ||
        node->depth >= data->params.max_depth )
        can_split = false;

    if( can_split && data->is_classifier )
    {
        // check if we have a "pure" node,
        // we assume that cls_count is filled by calc_node_value()
        int* cls_count = data->counts->data.i;
        int nz = 0, m = data->get_num_classes();
        for( i = 0; i < m; i++ )
            nz += cls_count[i] != 0;
        if( nz == 1 ) // there is only one class
            can_split = false;
    }
    else if( can_split )
    {
        // regression: stop when sqrt(node_risk)/n is below regression_accuracy
        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
            can_split = false;
    }

    if( can_split )
    {
        best_split = find_best_split(node);
        // TODO: check the split quality ...
        node->split = best_split;
    }

    // leaf: nothing to split on, release the per-node data and stop
    if( !can_split || !best_split )
    {
        data->free_node_data(node);
        return;
    }

    // fills data->direction for the primary split and returns the factor that
    // normalizes surrogate qualities relative to the primary split quality
    quality_scale = calc_node_dir( node );

    if( data->params.use_surrogates )
    {
        // find all the surrogate splits
        // and sort them by their similarity to the primary one
        for( vi = 0; vi < data->var_count; vi++ )
        {
            CvDTreeSplit* split;
            int ci = data->get_var_type(vi);

            if( vi == best_split->var_idx )
                continue;

            if( ci >= 0 )
                split = find_surrogate_split_cat( node, vi );
            else
                split = find_surrogate_split_ord( node, vi );

            if( split )
            {
                // insert the split into the chain after the primary split
                // (which stays at the head), keeping the surrogates ordered
                // by decreasing scaled quality
                CvDTreeSplit* prev_split = node->split;
                split->quality = (float)(split->quality*quality_scale);

                while( prev_split->next &&
                       prev_split->next->quality > split->quality )
                    prev_split = prev_split->next;
                split->next = prev_split->next;
                prev_split->next = split;
            }
        }
    }

    split_node_data( node );
    try_split_node( node->left );
    try_split_node( node->right );
}
1415
1416
1417// calculate direction (left(-1),right(1),missing(0))
1418// for each sample using the best split
1419// the function returns scale coefficients for surrogate split quality factors.
1420// the scale is applied to normalize surrogate split quality relatively to the
1421// best (primary) split quality. That is, if a surrogate split is absolutely
1422// identical to the primary split, its quality will be set to the maximum value =
1423// quality of the primary split; otherwise, it will be lower.
1424// besides, the function compute node->maxlr,
1425// minimum possible quality (w/o considering the above mentioned scale)
1426// for a surrogate split. Surrogate splits with quality less than node->maxlr
1427// are not discarded.
1428double CvDTree::calc_node_dir( CvDTreeNode* node )
1429{
1430    char* dir = (char*)data->direction->data.ptr;
1431    int i, n = node->sample_count, vi = node->split->var_idx;
1432    double L, R;
1433
1434    assert( !node->split->inversed );
1435
1436    if( data->get_var_type(vi) >= 0 ) // split on categorical var
1437    {
1438        const int* labels = data->get_cat_var_data(node,vi);
1439        const int* subset = node->split->subset;
1440
1441        if( !data->have_priors )
1442        {
1443            int sum = 0, sum_abs = 0;
1444
1445            for( i = 0; i < n; i++ )
1446            {
1447                int idx = labels[i];
1448                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1449                sum += d; sum_abs += d & 1;
1450                dir[i] = (char)d;
1451            }
1452
1453            R = (sum_abs + sum) >> 1;
1454            L = (sum_abs - sum) >> 1;
1455        }
1456        else
1457        {
1458            const int* responses = data->get_class_labels(node);
1459            const double* priors = data->priors_mult->data.db;
1460            double sum = 0, sum_abs = 0;
1461
1462            for( i = 0; i < n; i++ )
1463            {
1464                int idx = labels[i];
1465                double w = priors[responses[i]];
1466                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1467                sum += d*w; sum_abs += (d & 1)*w;
1468                dir[i] = (char)d;
1469            }
1470
1471            R = (sum_abs + sum) * 0.5;
1472            L = (sum_abs - sum) * 0.5;
1473        }
1474    }
1475    else // split on ordered var
1476    {
1477        const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
1478        int split_point = node->split->ord.split_point;
1479        int n1 = node->get_num_valid(vi);
1480
1481        assert( 0 <= split_point && split_point < n1-1 );
1482
1483        if( !data->have_priors )
1484        {
1485            for( i = 0; i <= split_point; i++ )
1486                dir[sorted[i].i] = (char)-1;
1487            for( ; i < n1; i++ )
1488                dir[sorted[i].i] = (char)1;
1489            for( ; i < n; i++ )
1490                dir[sorted[i].i] = (char)0;
1491
1492            L = split_point-1;
1493            R = n1 - split_point + 1;
1494        }
1495        else
1496        {
1497            const int* responses = data->get_class_labels(node);
1498            const double* priors = data->priors_mult->data.db;
1499            L = R = 0;
1500
1501            for( i = 0; i <= split_point; i++ )
1502            {
1503                int idx = sorted[i].i;
1504                double w = priors[responses[idx]];
1505                dir[idx] = (char)-1;
1506                L += w;
1507            }
1508
1509            for( ; i < n1; i++ )
1510            {
1511                int idx = sorted[i].i;
1512                double w = priors[responses[idx]];
1513                dir[idx] = (char)1;
1514                R += w;
1515            }
1516
1517            for( ; i < n; i++ )
1518                dir[sorted[i].i] = (char)0;
1519        }
1520    }
1521
1522    node->maxlr = MAX( L, R );
1523    return node->split->quality/(L + R);
1524}
1525
1526
1527CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1528{
1529    int vi;
1530    CvDTreeSplit *best_split = 0, *split = 0, *t;
1531
1532    for( vi = 0; vi < data->var_count; vi++ )
1533    {
1534        int ci = data->get_var_type(vi);
1535        if( node->get_num_valid(vi) <= 1 )
1536            continue;
1537
1538        if( data->is_classifier )
1539        {
1540            if( ci >= 0 )
1541                split = find_split_cat_class( node, vi );
1542            else
1543                split = find_split_ord_class( node, vi );
1544        }
1545        else
1546        {
1547            if( ci >= 0 )
1548                split = find_split_cat_reg( node, vi );
1549            else
1550                split = find_split_ord_reg( node, vi );
1551        }
1552
1553        if( split )
1554        {
1555            if( !best_split || best_split->quality < split->quality )
1556                CV_SWAP( best_split, split, t );
1557            if( split )
1558                cvSetRemoveByPtr( data->split_heap, split );
1559        }
1560    }
1561
1562    return best_split;
1563}
1564
1565
// Find the best threshold split of a classification node on ordered
// variable vi. Samples are scanned in sorted order, moving one sample at a
// time from the right side to the left and updating the class counters
// incrementally. Only positions between two values that differ by more than
// epsilon are admissible thresholds.
// Quality: (lsum2*R + rsum2*L)/(L*R) = lsum2/L + rsum2/R, where lsum2/rsum2
// are the sums of squared (prior-weighted) class counts and L/R the
// (prior-weighted) side sizes. Maximizing it is equivalent to minimizing the
// size-weighted Gini impurity of the two sides.
// Returns a new split with threshold halfway between the two neighboring
// values, or 0 if no admissible threshold exists.
CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
{
    const float epsilon = FLT_EPSILON*2;
    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    const int* responses = data->get_class_labels(node);
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);
    int m = data->get_num_classes();
    // per-class counts of the whole node; assumed to be filled by
    // calc_node_value() — NOTE(review): confirm
    const int* rc0 = data->counts->data.i;
    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
    int i, best_i = -1;
    double lsum2 = 0, rsum2 = 0, best_val = 0;
    const double* priors = data->have_priors ? data->priors_mult->data.db : 0;

    // init arrays of class instance counters on both sides of the split
    for( i = 0; i < m; i++ )
    {
        lc[i] = 0;
        rc[i] = rc0[i];
    }

    // compensate for missing values
    // (samples at sorted positions n1..n-1 have no value for vi and take no part)
    for( i = n1; i < n; i++ )
        rc[responses[sorted[i].i]]--;

    if( !priors )
    {
        int L = 0, R = n1;

        for( i = 0; i < m; i++ )
            rsum2 += (double)rc[i]*rc[i];

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted[i].i];
            int lv, rv;
            L++; R--;
            lv = lc[idx]; rv = rc[idx];
            // (lv+1)^2 - lv^2 = 2*lv + 1;  rv^2 - (rv-1)^2 = 2*rv - 1
            lsum2 += lv*2 + 1;
            rsum2 -= rv*2 - 1;
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            // only split between distinct values
            if( sorted[i].val + epsilon < sorted[i+1].val )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }
    else
    {
        // same scan with every sample weighted by the prior of its class
        double L = 0, R = 0;
        for( i = 0; i < m; i++ )
        {
            double wv = rc[i]*priors[i];
            R += wv;
            rsum2 += wv*wv;
        }

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted[i].i];
            int lv, rv;
            double p = priors[idx], p2 = p*p;
            L += p; R -= p;
            lv = lc[idx]; rv = rc[idx];
            lsum2 += p2*(lv*2 + 1);
            rsum2 -= p2*(rv*2 - 1);
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( sorted[i].val + epsilon < sorted[i+1].val )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }

    return best_i >= 0 ? data->new_split_ord( vi,
        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
        0, (float)best_val ) : 0;
}
1657
1658
1659void CvDTree::cluster_categories( const int* vectors, int n, int m,
1660                                int* csums, int k, int* labels )
1661{
1662    // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
1663    int iters = 0, max_iters = 100;
1664    int i, j, idx;
1665    double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
1666    double *v_weights = buf, *c_weights = buf + k;
1667    bool modified = true;
1668    CvRNG* r = &data->rng;
1669
1670    // assign labels randomly
1671    for( i = idx = 0; i < n; i++ )
1672    {
1673        int sum = 0;
1674        const int* v = vectors + i*m;
1675        labels[i] = idx++;
1676        idx &= idx < k ? -1 : 0;
1677
1678        // compute weight of each vector
1679        for( j = 0; j < m; j++ )
1680            sum += v[j];
1681        v_weights[i] = sum ? 1./sum : 0.;
1682    }
1683
1684    for( i = 0; i < n; i++ )
1685    {
1686        int i1 = cvRandInt(r) % n;
1687        int i2 = cvRandInt(r) % n;
1688        CV_SWAP( labels[i1], labels[i2], j );
1689    }
1690
1691    for( iters = 0; iters <= max_iters; iters++ )
1692    {
1693        // calculate csums
1694        for( i = 0; i < k; i++ )
1695        {
1696            for( j = 0; j < m; j++ )
1697                csums[i*m + j] = 0;
1698        }
1699
1700        for( i = 0; i < n; i++ )
1701        {
1702            const int* v = vectors + i*m;
1703            int* s = csums + labels[i]*m;
1704            for( j = 0; j < m; j++ )
1705                s[j] += v[j];
1706        }
1707
1708        // exit the loop here, when we have up-to-date csums
1709        if( iters == max_iters || !modified )
1710            break;
1711
1712        modified = false;
1713
1714        // calculate weight of each cluster
1715        for( i = 0; i < k; i++ )
1716        {
1717            const int* s = csums + i*m;
1718            int sum = 0;
1719            for( j = 0; j < m; j++ )
1720                sum += s[j];
1721            c_weights[i] = sum ? 1./sum : 0;
1722        }
1723
1724        // now for each vector determine the closest cluster
1725        for( i = 0; i < n; i++ )
1726        {
1727            const int* v = vectors + i*m;
1728            double alpha = v_weights[i];
1729            double min_dist2 = DBL_MAX;
1730            int min_idx = -1;
1731
1732            for( idx = 0; idx < k; idx++ )
1733            {
1734                const int* s = csums + idx*m;
1735                double dist2 = 0., beta = c_weights[idx];
1736                for( j = 0; j < m; j++ )
1737                {
1738                    double t = v[j]*alpha - s[j]*beta;
1739                    dist2 += t*t;
1740                }
1741                if( min_dist2 > dist2 )
1742                {
1743                    min_dist2 = dist2;
1744                    min_idx = idx;
1745                }
1746            }
1747
1748            if( min_idx != labels[i] )
1749                modified = true;
1750            labels[i] = min_idx;
1751        }
1752    }
1753}
1754
1755
1756CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
1757{
1758    CvDTreeSplit* split;
1759    const int* labels = data->get_cat_var_data(node, vi);
1760    const int* responses = data->get_class_labels(node);
1761    int ci = data->get_var_type(vi);
1762    int n = node->sample_count;
1763    int m = data->get_num_classes();
1764    int _mi = data->cat_count->data.i[ci], mi = _mi;
1765    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
1766    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
1767    int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;
1768    double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
1769    int* cluster_labels = 0;
1770    int** int_ptr = 0;
1771    int i, j, k, idx;
1772    double L = 0, R = 0;
1773    double best_val = 0;
1774    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
1775    const double* priors = data->priors_mult->data.db;
1776
1777    // init array of counters:
1778    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
1779    for( j = -1; j < mi; j++ )
1780        for( k = 0; k < m; k++ )
1781            cjk[j*m + k] = 0;
1782
1783    for( i = 0; i < n; i++ )
1784    {
1785        j = labels[i];
1786        k = responses[i];
1787        cjk[j*m + k]++;
1788    }
1789
1790    if( m > 2 )
1791    {
1792        if( mi > data->params.max_categories )
1793        {
1794            mi = MIN(data->params.max_categories, n);
1795            cjk += _mi*m;
1796            cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));
1797            cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
1798        }
1799        subset_i = 1;
1800        subset_n = 1 << mi;
1801    }
1802    else
1803    {
1804        assert( m == 2 );
1805        int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
1806        for( j = 0; j < mi; j++ )
1807            int_ptr[j] = cjk + j*2 + 1;
1808        icvSortIntPtr( int_ptr, mi, 0 );
1809        subset_i = 0;
1810        subset_n = mi;
1811    }
1812
1813    for( k = 0; k < m; k++ )
1814    {
1815        int sum = 0;
1816        for( j = 0; j < mi; j++ )
1817            sum += cjk[j*m + k];
1818        rc[k] = sum;
1819        lc[k] = 0;
1820    }
1821
1822    for( j = 0; j < mi; j++ )
1823    {
1824        double sum = 0;
1825        for( k = 0; k < m; k++ )
1826            sum += cjk[j*m + k]*priors[k];
1827        c_weights[j] = sum;
1828        R += c_weights[j];
1829    }
1830
1831    for( ; subset_i < subset_n; subset_i++ )
1832    {
1833        double weight;
1834        int* crow;
1835        double lsum2 = 0, rsum2 = 0;
1836
1837        if( m == 2 )
1838            idx = (int)(int_ptr[subset_i] - cjk)/2;
1839        else
1840        {
1841            int graycode = (subset_i>>1)^subset_i;
1842            int diff = graycode ^ prevcode;
1843
1844            // determine index of the changed bit.
1845            Cv32suf u;
1846            idx = diff >= (1 << 16) ? 16 : 0;
1847            u.f = (float)(((diff >> 16) | diff) & 65535);
1848            idx += (u.i >> 23) - 127;
1849            subtract = graycode < prevcode;
1850            prevcode = graycode;
1851        }
1852
1853        crow = cjk + idx*m;
1854        weight = c_weights[idx];
1855        if( weight < FLT_EPSILON )
1856            continue;
1857
1858        if( !subtract )
1859        {
1860            for( k = 0; k < m; k++ )
1861            {
1862                int t = crow[k];
1863                int lval = lc[k] + t;
1864                int rval = rc[k] - t;
1865                double p = priors[k], p2 = p*p;
1866                lsum2 += p2*lval*lval;
1867                rsum2 += p2*rval*rval;
1868                lc[k] = lval; rc[k] = rval;
1869            }
1870            L += weight;
1871            R -= weight;
1872        }
1873        else
1874        {
1875            for( k = 0; k < m; k++ )
1876            {
1877                int t = crow[k];
1878                int lval = lc[k] - t;
1879                int rval = rc[k] + t;
1880                double p = priors[k], p2 = p*p;
1881                lsum2 += p2*lval*lval;
1882                rsum2 += p2*rval*rval;
1883                lc[k] = lval; rc[k] = rval;
1884            }
1885            L -= weight;
1886            R += weight;
1887        }
1888
1889        if( L > FLT_EPSILON && R > FLT_EPSILON )
1890        {
1891            double val = (lsum2*R + rsum2*L)/((double)L*R);
1892            if( best_val < val )
1893            {
1894                best_val = val;
1895                best_subset = subset_i;
1896            }
1897        }
1898    }
1899
1900    if( best_subset < 0 )
1901        return 0;
1902
1903    split = data->new_split_cat( vi, (float)best_val );
1904
1905    if( m == 2 )
1906    {
1907        for( i = 0; i <= best_subset; i++ )
1908        {
1909            idx = (int)(int_ptr[i] - cjk) >> 1;
1910            split->subset[idx >> 5] |= 1 << (idx & 31);
1911        }
1912    }
1913    else
1914    {
1915        for( i = 0; i < _mi; i++ )
1916        {
1917            idx = cluster_labels ? cluster_labels[i] : i;
1918            if( best_subset & (1 << idx) )
1919                split->subset[i >> 5] |= 1 << (i & 31);
1920        }
1921    }
1922
1923    return split;
1924}
1925
1926
1927CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
1928{
1929    const float epsilon = FLT_EPSILON*2;
1930    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1931    const float* responses = data->get_ord_responses(node);
1932    int n = node->sample_count;
1933    int n1 = node->get_num_valid(vi);
1934    int i, best_i = -1;
1935    double best_val = 0, lsum = 0, rsum = node->value*n;
1936    int L = 0, R = n1;
1937
1938    // compensate for missing values
1939    for( i = n1; i < n; i++ )
1940        rsum -= responses[sorted[i].i];
1941
1942    // find the optimal split
1943    for( i = 0; i < n1 - 1; i++ )
1944    {
1945        float t = responses[sorted[i].i];
1946        L++; R--;
1947        lsum += t;
1948        rsum -= t;
1949
1950        if( sorted[i].val + epsilon < sorted[i+1].val )
1951        {
1952            double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1953            if( best_val < val )
1954            {
1955                best_val = val;
1956                best_i = i;
1957            }
1958        }
1959    }
1960
1961    return best_i >= 0 ? data->new_split_ord( vi,
1962        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1963        0, (float)best_val ) : 0;
1964}
1965
1966
1967CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
1968{
1969    CvDTreeSplit* split;
1970    const int* labels = data->get_cat_var_data(node, vi);
1971    const float* responses = data->get_ord_responses(node);
1972    int ci = data->get_var_type(vi);
1973    int n = node->sample_count;
1974    int mi = data->cat_count->data.i[ci];
1975    double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
1976    int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
1977    double** sum_ptr = 0;
1978    int i, L = 0, R = 0;
1979    double best_val = 0, lsum = 0, rsum = 0;
1980    int best_subset = -1, subset_i;
1981
1982    for( i = -1; i < mi; i++ )
1983        sum[i] = counts[i] = 0;
1984
1985    // calculate sum response and weight of each category of the input var
1986    for( i = 0; i < n; i++ )
1987    {
1988        int idx = labels[i];
1989        double s = sum[idx] + responses[i];
1990        int nc = counts[idx] + 1;
1991        sum[idx] = s;
1992        counts[idx] = nc;
1993    }
1994
1995    // calculate average response in each category
1996    for( i = 0; i < mi; i++ )
1997    {
1998        R += counts[i];
1999        rsum += sum[i];
2000        sum[i] /= MAX(counts[i],1);
2001        sum_ptr[i] = sum + i;
2002    }
2003
2004    icvSortDblPtr( sum_ptr, mi, 0 );
2005
2006    // revert back to unnormalized sums
2007    // (there should be a very little loss of accuracy)
2008    for( i = 0; i < mi; i++ )
2009        sum[i] *= counts[i];
2010
2011    for( subset_i = 0; subset_i < mi-1; subset_i++ )
2012    {
2013        int idx = (int)(sum_ptr[subset_i] - sum);
2014        int ni = counts[idx];
2015
2016        if( ni )
2017        {
2018            double s = sum[idx];
2019            lsum += s; L += ni;
2020            rsum -= s; R -= ni;
2021
2022            if( L && R )
2023            {
2024                double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2025                if( best_val < val )
2026                {
2027                    best_val = val;
2028                    best_subset = subset_i;
2029                }
2030            }
2031        }
2032    }
2033
2034    if( best_subset < 0 )
2035        return 0;
2036
2037    split = data->new_split_cat( vi, (float)best_val );
2038    for( i = 0; i <= best_subset; i++ )
2039    {
2040        int idx = (int)(sum_ptr[i] - sum);
2041        split->subset[idx >> 5] |= 1 << (idx & 31);
2042    }
2043
2044    return split;
2045}
2046
2047
2048CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
2049{
2050    const float epsilon = FLT_EPSILON*2;
2051    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2052    const char* dir = (char*)data->direction->data.ptr;
2053    int n1 = node->get_num_valid(vi);
2054    // LL - number of samples that both the primary and the surrogate splits send to the left
2055    // LR - ... primary split sends to the left and the surrogate split sends to the right
2056    // RL - ... primary split sends to the right and the surrogate split sends to the left
2057    // RR - ... both send to the right
2058    int i, best_i = -1, best_inversed = 0;
2059    double best_val;
2060
2061    if( !data->have_priors )
2062    {
2063        int LL = 0, RL = 0, LR, RR;
2064        int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
2065        int sum = 0, sum_abs = 0;
2066
2067        for( i = 0; i < n1; i++ )
2068        {
2069            int d = dir[sorted[i].i];
2070            sum += d; sum_abs += d & 1;
2071        }
2072
2073        // sum_abs = R + L; sum = R - L
2074        RR = (sum_abs + sum) >> 1;
2075        LR = (sum_abs - sum) >> 1;
2076
2077        // initially all the samples are sent to the right by the surrogate split,
2078        // LR of them are sent to the left by primary split, and RR - to the right.
2079        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2080        for( i = 0; i < n1 - 1; i++ )
2081        {
2082            int d = dir[sorted[i].i];
2083
2084            if( d < 0 )
2085            {
2086                LL++; LR--;
2087                if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
2088                {
2089                    best_val = LL + RR;
2090                    best_i = i; best_inversed = 0;
2091                }
2092            }
2093            else if( d > 0 )
2094            {
2095                RL++; RR--;
2096                if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
2097                {
2098                    best_val = RL + LR;
2099                    best_i = i; best_inversed = 1;
2100                }
2101            }
2102        }
2103        best_val = _best_val;
2104    }
2105    else
2106    {
2107        double LL = 0, RL = 0, LR, RR;
2108        double worst_val = node->maxlr;
2109        double sum = 0, sum_abs = 0;
2110        const double* priors = data->priors_mult->data.db;
2111        const int* responses = data->get_class_labels(node);
2112        best_val = worst_val;
2113
2114        for( i = 0; i < n1; i++ )
2115        {
2116            int idx = sorted[i].i;
2117            double w = priors[responses[idx]];
2118            int d = dir[idx];
2119            sum += d*w; sum_abs += (d & 1)*w;
2120        }
2121
2122        // sum_abs = R + L; sum = R - L
2123        RR = (sum_abs + sum)*0.5;
2124        LR = (sum_abs - sum)*0.5;
2125
2126        // initially all the samples are sent to the right by the surrogate split,
2127        // LR of them are sent to the left by primary split, and RR - to the right.
2128        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2129        for( i = 0; i < n1 - 1; i++ )
2130        {
2131            int idx = sorted[i].i;
2132            double w = priors[responses[idx]];
2133            int d = dir[idx];
2134
2135            if( d < 0 )
2136            {
2137                LL += w; LR -= w;
2138                if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
2139                {
2140                    best_val = LL + RR;
2141                    best_i = i; best_inversed = 0;
2142                }
2143            }
2144            else if( d > 0 )
2145            {
2146                RL += w; RR -= w;
2147                if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
2148                {
2149                    best_val = RL + LR;
2150                    best_i = i; best_inversed = 1;
2151                }
2152            }
2153        }
2154    }
2155
2156    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
2157        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
2158        best_inversed, (float)best_val ) : 0;
2159}
2160
2161
2162CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
2163{
2164    const int* labels = data->get_cat_var_data(node, vi);
2165    const char* dir = (char*)data->direction->data.ptr;
2166    int n = node->sample_count;
2167    // LL - number of samples that both the primary and the surrogate splits send to the left
2168    // LR - ... primary split sends to the left and the surrogate split sends to the right
2169    // RL - ... primary split sends to the right and the surrogate split sends to the left
2170    // RR - ... both send to the right
2171    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
2172    int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
2173    double best_val = 0;
2174    double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
2175    double* rc = lc + mi + 1;
2176
2177    for( i = -1; i < mi; i++ )
2178        lc[i] = rc[i] = 0;
2179
2180    // for each category calculate the weight of samples
2181    // sent to the left (lc) and to the right (rc) by the primary split
2182    if( !data->have_priors )
2183    {
2184        int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;
2185        int* _rc = _lc + mi + 1;
2186
2187        for( i = -1; i < mi; i++ )
2188            _lc[i] = _rc[i] = 0;
2189
2190        for( i = 0; i < n; i++ )
2191        {
2192            int idx = labels[i];
2193            int d = dir[i];
2194            int sum = _lc[idx] + d;
2195            int sum_abs = _rc[idx] + (d & 1);
2196            _lc[idx] = sum; _rc[idx] = sum_abs;
2197        }
2198
2199        for( i = 0; i < mi; i++ )
2200        {
2201            int sum = _lc[i];
2202            int sum_abs = _rc[i];
2203            lc[i] = (sum_abs - sum) >> 1;
2204            rc[i] = (sum_abs + sum) >> 1;
2205        }
2206    }
2207    else
2208    {
2209        const double* priors = data->priors_mult->data.db;
2210        const int* responses = data->get_class_labels(node);
2211
2212        for( i = 0; i < n; i++ )
2213        {
2214            int idx = labels[i];
2215            double w = priors[responses[i]];
2216            int d = dir[i];
2217            double sum = lc[idx] + d*w;
2218            double sum_abs = rc[idx] + (d & 1)*w;
2219            lc[idx] = sum; rc[idx] = sum_abs;
2220        }
2221
2222        for( i = 0; i < mi; i++ )
2223        {
2224            double sum = lc[i];
2225            double sum_abs = rc[i];
2226            lc[i] = (sum_abs - sum) * 0.5;
2227            rc[i] = (sum_abs + sum) * 0.5;
2228        }
2229    }
2230
2231    // 2. now form the split.
2232    // in each category send all the samples to the same direction as majority
2233    for( i = 0; i < mi; i++ )
2234    {
2235        double lval = lc[i], rval = rc[i];
2236        if( lval > rval )
2237        {
2238            split->subset[i >> 5] |= 1 << (i & 31);
2239            best_val += lval;
2240            l_win++;
2241        }
2242        else
2243            best_val += rval;
2244    }
2245
2246    split->quality = (float)best_val;
2247    if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
2248        cvSetRemoveByPtr( data->split_heap, split ), split = 0;
2249
2250    return split;
2251}
2252
2253
// Computes node->value and node->node_risk for `node`. When cross-validation
// is enabled (cv_folds > 0), it also fills, for every fold j, cv_node_risk[j],
// cv_node_error[j] and resets cv_Tn[j] to INT_MAX. prune_cv() consumes these.
// For the root of a classifier with priors, it also recomputes
// data->priors_mult from data->priors and the class counts.
void CvDTree::calc_node_value( CvDTreeNode* node )
{
    int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
    const int* cv_labels = data->get_labels(node);

    if( data->is_classifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        int* cls_count = data->counts->data.i;
        const int* responses = data->get_class_labels(node);
        int m = data->get_num_classes();
        // cv_cls_count[j*m + k] = number of class-k samples that belong to fold j
        int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));
        double max_val = -1, total_weight = 0;
        int max_k = -1;
        double* priors = data->priors_mult->data.db;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
                cls_count[responses[i]]++;
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i]; k = responses[i];
                cv_cls_count[j*m + k]++;
            }

            // total per-class counts are the sums over all folds
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        if( data->have_priors && node->parent == 0 )
        {
            // compute priors_mult from priors, take the sample ratio into account.
            // priors_mult[k] ~ priors[k]/n_k, normalized to sum to 1 (0 for absent classes).
            double sum = 0;
            for( k = 0; k < m; k++ )
            {
                int n_k = cls_count[k];
                priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
                sum += priors[k];
            }
            sum = 1./sum;
            for( k = 0; k < m; k++ )
                priors[k] *= sum;
        }

        // pick the class with the largest prior-weighted count
        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k]*priors[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        // NOTE(review): presumably maps the internal class index back to the original
        // response label via the response variable's category map — confirm.
        node->value = data->cat_map->data.i[
            data->cat_ofs->data.i[data->cat_var_count] + max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            // val   = weight of class k outside fold j (the "training" part)
            // val_k = weight of class k inside fold j (the "test" part)
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double w = priors[k];
                double val_k = cv_cls_count[j*m + k]*w;
                double val = cls_count[k]*w - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            // INT_MAX marks the node as not yet pruned in fold j's tree sequence
            node->cv_Tn[j] = INT_MAX;
            node->cv_node_risk[j] = sum - max_val;
            node->cv_node_error[j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.

        double sum = 0, sum2 = 0;
        const float* values = data->get_ord_responses(node);
        double *cv_sum = 0, *cv_sum2 = 0;
        int* cv_count = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                double t = values[i];
                sum += t;
                sum2 += t*t;
            }
        }
        else
        {
            // per-fold sum, sum of squares and sample count
            cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
            cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
            cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i];
                double t = values[i];
                double s = cv_sum[j] + t;
                double s2 = cv_sum2[j] + t*t;
                int nc = cv_count[j] + 1;
                cv_sum[j] = s;
                cv_sum2[j] = s2;
                cv_count[j] = nc;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
            }
        }

        // SSE = sum(Y^2) - (sum Y)^2/n
        node->node_risk = sum2 - (sum/n)*sum;
        node->value = sum/n;

        for( j = 0; j < cv_n; j++ )
        {
            // s/s2/c: fold j (test part); si/s2i/ci: all other folds (training part)
            double s = cv_sum[j], si = sum - s;
            double s2 = cv_sum2[j], s2i = sum2 - s2;
            int c = cv_count[j], ci = n - c;
            // r = node value estimated without fold j
            double r = si/MAX(ci,1);
            node->cv_node_risk[j] = s2i - r*r*ci;
            // sum over fold j of (Y - r)^2, expanded
            node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
            node->cv_Tn[j] = INT_MAX;
        }
    }
}
2434
2435
2436void CvDTree::complete_node_dir( CvDTreeNode* node )
2437{
2438    int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2439    int nz = n - node->get_num_valid(node->split->var_idx);
2440    char* dir = (char*)data->direction->data.ptr;
2441
2442    // try to complete direction using surrogate splits
2443    if( nz && data->params.use_surrogates )
2444    {
2445        CvDTreeSplit* split = node->split->next;
2446        for( ; split != 0 && nz; split = split->next )
2447        {
2448            int inversed_mask = split->inversed ? -1 : 0;
2449            vi = split->var_idx;
2450
2451            if( data->get_var_type(vi) >= 0 ) // split on categorical var
2452            {
2453                const int* labels = data->get_cat_var_data(node, vi);
2454                const int* subset = split->subset;
2455
2456                for( i = 0; i < n; i++ )
2457                {
2458                    int idx;
2459                    if( !dir[i] && (idx = labels[i]) >= 0 )
2460                    {
2461                        int d = CV_DTREE_CAT_DIR(idx,subset);
2462                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2463                        if( --nz )
2464                            break;
2465                    }
2466                }
2467            }
2468            else // split on ordered var
2469            {
2470                const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2471                int split_point = split->ord.split_point;
2472                int n1 = node->get_num_valid(vi);
2473
2474                assert( 0 <= split_point && split_point < n-1 );
2475
2476                for( i = 0; i < n1; i++ )
2477                {
2478                    int idx = sorted[i].i;
2479                    if( !dir[idx] )
2480                    {
2481                        int d = i <= split_point ? -1 : 1;
2482                        dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2483                        if( --nz )
2484                            break;
2485                    }
2486                }
2487            }
2488        }
2489    }
2490
2491    // find the default direction for the rest
2492    if( nz )
2493    {
2494        for( i = nr = 0; i < n; i++ )
2495            nr += dir[i] > 0;
2496        nl = n - nr - nz;
2497        d0 = nl > nr ? -1 : nr > nl;
2498    }
2499
2500    // make sure that every sample is directed either to the left or to the right
2501    for( i = 0; i < n; i++ )
2502    {
2503        int d = dir[i];
2504        if( !d )
2505        {
2506            d = d0;
2507            if( !d )
2508                d = d1, d1 = -d1;
2509        }
2510        d = d > 0;
2511        dir[i] = (char)d; // remap (-1,1) to (0,1)
2512    }
2513}
2514
2515
// Creates node->left / node->right and distributes the parent's per-variable
// sample data between them according to data->direction, which after
// complete_node_dir() holds 0 (left) or 1 (right) per sample.
// Ordered variables stay sorted in each child; categorical variables, responses
// and cv labels (vi >= var_count) are relocated in sample order.
// The parent's data is freed at the end.
void CvDTree::split_node_data( CvDTreeNode* node )
{
    int vi, i, n = node->sample_count, nl, nr;
    char* dir = (char*)data->direction->data.ptr;
    CvDTreeNode *left = 0, *right = 0;
    int* new_idx = data->split_buf->data.i;
    int new_buf_idx = data->get_child_buf_idx( node );
    int work_var_count = data->get_work_var_count();

    // speedup things a little, especially for tree ensembles with a lots of small trees:
    //   do not physically split the input data between the left and right child nodes
    //   when we are not going to split them further,
    //   as calc_node_value() does not requires input features anyway.
    bool split_input_data;

    complete_node_dir(node);

    for( i = nl = nr = 0; i < n; i++ )
    {
        int d = dir[i];
        // initialize new indices for splitting ordered variables
        // branch-free select: d == 0 -> nl, d == 1 -> nr (position within the child)
        new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
        nr += d;
        nl += d^1;
    }

    node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
    node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
        (data->ord_var_count + work_var_count)*nl );

    split_input_data = node->depth + 1 < data->params.max_depth &&
        (node->left->sample_count > data->params.min_sample_count ||
        node->right->sample_count > data->params.min_sample_count);

    // split ordered variables, keep both halves sorted.
    for( vi = 0; vi < data->var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi);
        CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
        CvPair32s32f tl, tr;

        if( ci >= 0 || !split_input_data )
            continue;

        src = data->get_ord_var_data(node, vi);
        ldst0 = ldst = data->get_ord_var_data(left, vi);
        rdst0 = rdst = data->get_ord_var_data(right, vi);
        // Each element is written to BOTH destinations and only one pointer advances,
        // so one element past each child's end gets overwritten; save it and restore below.
        tl = ldst0[nl]; tr = rdst0[nr];

        // split sorted
        for( i = 0; i < n1; i++ )
        {
            int idx = src[i].i;
            float val = src[i].val;
            int d = dir[idx];
            idx = new_idx[idx];
            ldst->i = rdst->i = idx;
            ldst->val = rdst->val = val;
            ldst += d^1;
            rdst += d;
        }

        left->set_num_valid(vi, (int)(ldst - ldst0));
        right->set_num_valid(vi, (int)(rdst - rdst0));

        // split missing
        for( ; i < n; i++ )
        {
            int idx = src[i].i;
            int d = dir[idx];
            idx = new_idx[idx];
            ldst->i = rdst->i = idx;
            ldst->val = rdst->val = ord_nan;
            ldst += d^1;
            rdst += d;
        }

        ldst0[nl] = tl; rdst0[nr] = tr;
    }

    // split categorical vars, responses and cv_labels using new_idx relocation table
    for( vi = 0; vi < work_var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi), nr1 = 0;
        int *src, *ldst0, *rdst0, *ldst, *rdst;
        int tl, tr;

        // input categorical vars are skipped when children won't be split further;
        // entries with vi >= var_count (responses, cv labels) are always split
        if( ci < 0 || (vi < data->var_count && !split_input_data) )
            continue;

        src = data->get_cat_var_data(node, vi);
        ldst0 = ldst = data->get_cat_var_data(left, vi);
        rdst0 = rdst = data->get_cat_var_data(right, vi);
        // same one-past-the-end save/restore trick as for ordered variables
        tl = ldst0[nl]; tr = rdst0[nr];

        for( i = 0; i < n; i++ )
        {
            int d = dir[i];
            int val = src[i];
            *ldst = *rdst = val;
            ldst += d^1;
            rdst += d;
            // count valid (non-negative) labels that go right
            nr1 += (val >= 0)&d;
        }

        if( vi < data->var_count )
        {
            left->set_num_valid(vi, n1 - nr1);
            right->set_num_valid(vi, nr1);
        }

        ldst0[nl] = tl; rdst0[nr] = tr;
    }

    // deallocate the parent node data that is not needed anymore
    data->free_node_data(node);
}
2635
2636
// Cost-complexity pruning with cross-validation.
// Builds the main sequence of nested subtrees (one per weakest-link cut) and
// records the alpha at which each cut happens. It then replays the sequence
// on every cv fold, measuring each fold's test error at representative alphas.
// The subtree with the smallest total cv error is chosen (optionally via the
// 1SE rule for classifiers); its index is stored in pruned_tree_idx, and
// free_prune_data() releases the cv data (and cuts the tree if requested).
void CvDTree::prune_cv()
{
    CvMat* ab = 0;
    CvMat* temp = 0;
    CvMat* err_jk = 0;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if need, apply 1SE rule).
    // 3. store the best index and cut the branches.

    CV_FUNCNAME( "CvDTree::prune_cv" );

    __BEGIN__;

    int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
    double* err;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // ab holds alpha values of the main sequence; grown by x1.5 when full
    CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = update_tree_rnc(tree_count, -1);
        if( cut_tree(tree_count, -1, min_alpha) )
            break;

        if( ab->cols <= tree_count )
        {
            CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
            for( ti = 0; ti < ab->cols; ti++ )
                temp->data.db[ti] = ab->data.db[ti];
            cvReleaseMat( &ab );
            ab = temp;
            temp = 0;
        }

        ab->data.db[tree_count] = min_alpha;
    }

    ab->data.db[0] = 0.;

    if( tree_count > 0 )
    {
        // representative alpha for each subtree: geometric mean of consecutive alphas;
        // the last one is effectively +infinity
        for( ti = 1; ti < tree_count-1; ti++ )
            ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
        ab->data.db[tree_count-1] = DBL_MAX*0.5;

        // err[j*tree_count + tk] = fold-j test error of the fold-j subtree matching alpha_tk
        CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
        err = err_jk->data.db;

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tk < tree_count; tj++ )
            {
                double min_alpha = update_tree_rnc(tj, j);
                if( cut_tree(tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab->data.db[tk] > min_alpha )
                        break;
                    err[j*tree_count + tk] = root->tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err[j*tree_count + ti];
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                // NOTE(review): a binomial standard error would usually be
                // sqrt(err*(n - err)/n); confirm the intended scaling here.
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            // 1SE rule: prefer a later (smaller) subtree within one SE of the minimum
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    pruned_tree_idx = min_idx;
    free_prune_data(data->params.truncate_pruned_tree != 0);

    __END__;

    cvReleaseMat( &err_jk );
    cvReleaseMat( &ab );
    cvReleaseMat( &temp );
}
2735
2736
// Recomputes, for subtree number T of the pruning sequence (fold < 0: main
// tree; otherwise the tree of cv fold `fold`), each node's:
//   complexity - number of leaves,
//   tree_risk  - total risk of the subtree,
//   tree_error - cv test error of the subtree (0 for the main tree).
// A node counts as a leaf if it was pruned at or before T (Tn <= T) or has no children.
// For every internal node it also sets
//   alpha = (own risk - subtree risk)/(leaves - 1)
// and returns the minimum alpha (the weakest link), or DBL_MAX if there is none.
// The traversal is a stackless post-order walk that uses parent pointers.
double CvDTree::update_tree_rnc( int T, int fold )
{
    CvDTreeNode* node = root;
    double min_alpha = DBL_MAX;

    for(;;)
    {
        CvDTreeNode* parent;
        // descend along left children to the first node that is a leaf of subtree T
        for(;;)
        {
            int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
            if( t <= T || !node->left )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = node->cv_node_risk[fold];
                    node->tree_error = node->cv_node_error[fold];
                }
                break;
            }
            node = node->left;
        }

        // climb while coming from a right child: both subtrees of parent are done,
        // so add the right child's totals and finalize parent's alpha
        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
        {
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
                - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = MIN( min_alpha, parent->alpha );
        }

        if( !parent )
            break;

        // came from a left child: seed parent's totals with it, then process the right subtree
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        node = parent->right;
    }

    return min_alpha;
}
2786
2787
// Produces subtree T+1 of the pruning sequence (main tree if fold < 0,
// otherwise the tree of cv fold `fold`). Every node still internal in subtree T
// whose alpha is within FLT_EPSILON of min_alpha is marked as pruned at step T
// (Tn/cv_Tn = T), and its descendants are skipped.
// Returns 1 if the root itself ends up pruned (or has no children), i.e. the
// sequence is complete; otherwise returns 0.
int CvDTree::cut_tree( int T, int fold, double min_alpha )
{
    CvDTreeNode* node = root;
    if( !node->left )
        return 1;

    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
            // already a leaf in subtree T: nothing to cut below
            if( t <= T || !node->left )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                if( fold >= 0 )
                    node->cv_Tn[fold] = T;
                else
                    node->Tn = T;
                if( node == root )
                    return 1;
                break;
            }
            node = node->left;
        }

        // climb up past nodes reached from their parent's right child
        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
            ;

        if( !parent )
            break;

        node = parent->right;
    }

    return 0;
}
2827
2828
// Detaches the cross-validation arrays (cv_Tn, cv_node_risk, cv_node_error)
// from every node, then clears data->cv_heap in one go.
// If cut_tree is true, it also frees both children of every internal node with
// Tn <= pruned_tree_idx, which turns that node into a leaf.
void CvDTree::free_prune_data(bool cut_tree)
{
    CvDTreeNode* node = root;

    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
            // as we will clear the whole cross-validation heap at the end
            node->cv_Tn = 0;
            node->cv_node_error = node->cv_node_risk = 0;
            if( !node->left )
                break;
            node = node->left;
        }

        // both subtrees of `parent` have been visited when we arrive from its right child
        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
        {
            if( cut_tree && parent->Tn <= pruned_tree_idx )
            {
                data->free_node( parent->left );
                data->free_node( parent->right );
                parent->left = parent->right = 0;
            }
        }

        if( !parent )
            break;

        node = parent->right;
    }

    if( data->cv_heap )
        cvClearSet( data->cv_heap );
}
2867
2868
2869void CvDTree::free_tree()
2870{
2871    if( root && data && data->shared )
2872    {
2873        pruned_tree_idx = INT_MIN;
2874        free_prune_data(true);
2875        data->free_node(root);
2876        root = 0;
2877    }
2878}
2879
2880
2881CvDTreeNode* CvDTree::predict( const CvMat* _sample,
2882    const CvMat* _missing, bool preprocessed_input ) const
2883{
2884    CvDTreeNode* result = 0;
2885    int* catbuf = 0;
2886
2887    CV_FUNCNAME( "CvDTree::predict" );
2888
2889    __BEGIN__;
2890
2891    int i, step, mstep = 0;
2892    const float* sample;
2893    const uchar* m = 0;
2894    CvDTreeNode* node = root;
2895    const int* vtype;
2896    const int* vidx;
2897    const int* cmap;
2898    const int* cofs;
2899
2900    if( !node )
2901        CV_ERROR( CV_StsError, "The tree has not been trained yet" );
2902
2903    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
2904        _sample->cols != 1 && _sample->rows != 1 ||
2905        _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
2906        _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
2907            CV_ERROR( CV_StsBadArg,
2908        "the input sample must be 1d floating-point vector with the same "
2909        "number of elements as the total number of variables used for training" );
2910
2911    sample = _sample->data.fl;
2912    step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
2913
2914    if( data->cat_count && !preprocessed_input ) // cache for categorical variables
2915    {
2916        int n = data->cat_count->cols;
2917        catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
2918        for( i = 0; i < n; i++ )
2919            catbuf[i] = -1;
2920    }
2921
2922    if( _missing )
2923    {
2924        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
2925        !CV_ARE_SIZES_EQ(_missing, _sample) )
2926            CV_ERROR( CV_StsBadArg,
2927        "the missing data mask must be 8-bit vector of the same size as input sample" );
2928        m = _missing->data.ptr;
2929        mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
2930    }
2931
2932    vtype = data->var_type->data.i;
2933    vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
2934    cmap = data->cat_map ? data->cat_map->data.i : 0;
2935    cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
2936
2937    while( node->Tn > pruned_tree_idx && node->left )
2938    {
2939        CvDTreeSplit* split = node->split;
2940        int dir = 0;
2941        for( ; !dir && split != 0; split = split->next )
2942        {
2943            int vi = split->var_idx;
2944            int ci = vtype[vi];
2945            i = vidx ? vidx[vi] : vi;
2946            float val = sample[i*step];
2947            if( m && m[i*mstep] )
2948                continue;
2949            if( ci < 0 ) // ordered
2950                dir = val <= split->ord.c ? -1 : 1;
2951            else // categorical
2952            {
2953                int c;
2954                if( preprocessed_input )
2955                    c = cvRound(val);
2956                else
2957                {
2958                    c = catbuf[ci];
2959                    if( c < 0 )
2960                    {
2961                        int a = c = cofs[ci];
2962                        int b = cofs[ci+1];
2963                        int ival = cvRound(val);
2964                        if( ival != val )
2965                            CV_ERROR( CV_StsBadArg,
2966                            "one of input categorical variable is not an integer" );
2967
2968                        while( a < b )
2969                        {
2970                            c = (a + b) >> 1;
2971                            if( ival < cmap[c] )
2972                                b = c;
2973                            else if( ival > cmap[c] )
2974                                a = c+1;
2975                            else
2976                                break;
2977                        }
2978
2979                        if( c < 0 || ival != cmap[c] )
2980                            continue;
2981
2982                        catbuf[ci] = c -= cofs[ci];
2983                    }
2984                }
2985                dir = CV_DTREE_CAT_DIR(c, split->subset);
2986            }
2987
2988            if( split->inversed )
2989                dir = -dir;
2990        }
2991
2992        if( !dir )
2993        {
2994            double diff = node->right->sample_count - node->left->sample_count;
2995            dir = diff < 0 ? -1 : 1;
2996        }
2997        node = dir < 0 ? node->left : node->right;
2998    }
2999
3000    result = node;
3001
3002    __END__;
3003
3004    return result;
3005}
3006
3007
// Returns the per-variable importance vector (1 x data->var_count, CV_64F), computing
// and caching it in var_importance on the first call. Returns 0 if there is no tree yet.
// Importance of a variable = sum of split->quality over every split in the split list of
// every internal node that is active for the current pruned_tree_idx (Tn > pruned_tree_idx);
// nodes below a pruned node are not visited. The vector is then normalized to unit L1 norm.
const CvMat* CvDTree::get_var_importance()
{
    if( !var_importance )
    {
        CvDTreeNode* node = root;
        double* importance;
        if( !node )
            return 0;
        var_importance = cvCreateMat( 1, data->var_count, CV_64F );
        cvZero( var_importance );
        importance = var_importance->data.db;

        // iterative pre-order traversal driven by parent pointers
        for(;;)
        {
            CvDTreeNode* parent;
            for( ;; node = node->left )
            {
                CvDTreeSplit* split = node->split;

                // leaves and nodes cut in the selected pruned tree contribute nothing
                if( !node->left || node->Tn <= pruned_tree_idx )
                    break;

                for( ; split != 0; split = split->next )
                    importance[split->var_idx] += split->quality;
            }

            // climb while returning from right children (finished subtrees)
            for( parent = node->parent; parent && parent->right == node;
                node = parent, parent = parent->parent )
                ;

            if( !parent )
                break;

            node = parent->right;
        }

        cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
    }

    return var_importance;
}
3049
3050
3051void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
3052{
3053    int ci;
3054
3055    cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
3056    cvWriteInt( fs, "var", split->var_idx );
3057    cvWriteReal( fs, "quality", split->quality );
3058
3059    ci = data->get_var_type(split->var_idx);
3060    if( ci >= 0 ) // split on a categorical var
3061    {
3062        int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
3063        for( i = 0; i < n; i++ )
3064            to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
3065
3066        // ad-hoc rule when to use inverse categorical split notation
3067        // to achieve more compact and clear representation
3068        default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
3069
3070        cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
3071                            "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
3072
3073        for( i = 0; i < n; i++ )
3074        {
3075            int dir = CV_DTREE_CAT_DIR(i,split->subset);
3076            if( dir*default_dir < 0 )
3077                cvWriteInt( fs, 0, i );
3078        }
3079        cvEndWriteStruct( fs );
3080    }
3081    else
3082        cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
3083
3084    cvEndWriteStruct( fs );
3085}
3086
3087
3088void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
3089{
3090    CvDTreeSplit* split;
3091
3092    cvStartWriteStruct( fs, 0, CV_NODE_MAP );
3093
3094    cvWriteInt( fs, "depth", node->depth );
3095    cvWriteInt( fs, "sample_count", node->sample_count );
3096    cvWriteReal( fs, "value", node->value );
3097
3098    if( data->is_classifier )
3099        cvWriteInt( fs, "norm_class_idx", node->class_idx );
3100
3101    cvWriteInt( fs, "Tn", node->Tn );
3102    cvWriteInt( fs, "complexity", node->complexity );
3103    cvWriteReal( fs, "alpha", node->alpha );
3104    cvWriteReal( fs, "node_risk", node->node_risk );
3105    cvWriteReal( fs, "tree_risk", node->tree_risk );
3106    cvWriteReal( fs, "tree_error", node->tree_error );
3107
3108    if( node->left )
3109    {
3110        cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
3111
3112        for( split = node->split; split != 0; split = split->next )
3113            write_split( fs, split );
3114
3115        cvEndWriteStruct( fs );
3116    }
3117
3118    cvEndWriteStruct( fs );
3119}
3120
3121
3122void CvDTree::write_tree_nodes( CvFileStorage* fs )
3123{
3124    //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
3125
3126    __BEGIN__;
3127
3128    CvDTreeNode* node = root;
3129
3130    // traverse the tree and save all the nodes in depth-first order
3131    for(;;)
3132    {
3133        CvDTreeNode* parent;
3134        for(;;)
3135        {
3136            write_node( fs, node );
3137            if( !node->left )
3138                break;
3139            node = node->left;
3140        }
3141
3142        for( parent = node->parent; parent && parent->right == node;
3143            node = parent, parent = parent->parent )
3144            ;
3145
3146        if( !parent )
3147            break;
3148
3149        node = parent->right;
3150    }
3151
3152    __END__;
3153}
3154
3155
3156void CvDTree::write( CvFileStorage* fs, const char* name )
3157{
3158    //CV_FUNCNAME( "CvDTree::write" );
3159
3160    __BEGIN__;
3161
3162    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
3163
3164    get_var_importance();
3165    data->write_params( fs );
3166    if( var_importance )
3167        cvWrite( fs, "var_importance", var_importance );
3168    write( fs );
3169
3170    cvEndWriteStruct( fs );
3171
3172    __END__;
3173}
3174
3175
3176void CvDTree::write( CvFileStorage* fs )
3177{
3178    //CV_FUNCNAME( "CvDTree::write" );
3179
3180    __BEGIN__;
3181
3182    cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
3183
3184    cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
3185    write_tree_nodes( fs );
3186    cvEndWriteStruct( fs );
3187
3188    __END__;
3189}
3190
3191
3192CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3193{
3194    CvDTreeSplit* split = 0;
3195
3196    CV_FUNCNAME( "CvDTree::read_split" );
3197
3198    __BEGIN__;
3199
3200    int vi, ci;
3201
3202    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3203        CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3204
3205    vi = cvReadIntByName( fs, fnode, "var", -1 );
3206    if( (unsigned)vi >= (unsigned)data->var_count )
3207        CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3208
3209    ci = data->get_var_type(vi);
3210    if( ci >= 0 ) // split on categorical var
3211    {
3212        int i, n = data->cat_count->data.i[ci], inversed = 0, val;
3213        CvSeqReader reader;
3214        CvFileNode* inseq;
3215        split = data->new_split_cat( vi, 0 );
3216        inseq = cvGetFileNodeByName( fs, fnode, "in" );
3217        if( !inseq )
3218        {
3219            inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3220            inversed = 1;
3221        }
3222        if( !inseq ||
3223            (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
3224            CV_ERROR( CV_StsParseError,
3225            "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3226
3227        if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
3228        {
3229            val = inseq->data.i;
3230            if( (unsigned)val >= (unsigned)n )
3231                CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3232
3233            split->subset[val >> 5] |= 1 << (val & 31);
3234        }
3235        else
3236        {
3237            cvStartReadSeq( inseq->data.seq, &reader );
3238
3239            for( i = 0; i < reader.seq->total; i++ )
3240            {
3241                CvFileNode* inode = (CvFileNode*)reader.ptr;
3242                val = inode->data.i;
3243                if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3244                    CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3245
3246                split->subset[val >> 5] |= 1 << (val & 31);
3247                CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3248            }
3249        }
3250
3251        // for categorical splits we do not use inversed splits,
3252        // instead we inverse the variable set in the split
3253        if( inversed )
3254            for( i = 0; i < (n + 31) >> 5; i++ )
3255                split->subset[i] ^= -1;
3256    }
3257    else
3258    {
3259        CvFileNode* cmp_node;
3260        split = data->new_split_ord( vi, 0, 0, 0, 0 );
3261
3262        cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3263        if( !cmp_node )
3264        {
3265            cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3266            split->inversed = 1;
3267        }
3268
3269        split->ord.c = (float)cvReadReal( cmp_node );
3270    }
3271
3272    split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
3273
3274    __END__;
3275
3276    return split;
3277}
3278
3279
3280CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3281{
3282    CvDTreeNode* node = 0;
3283
3284    CV_FUNCNAME( "CvDTree::read_node" );
3285
3286    __BEGIN__;
3287
3288    CvFileNode* splits;
3289    int i, depth;
3290
3291    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3292        CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3293
3294    CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3295    depth = cvReadIntByName( fs, fnode, "depth", -1 );
3296    if( depth != node->depth )
3297        CV_ERROR( CV_StsParseError, "incorrect node depth" );
3298
3299    node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3300    node->value = cvReadRealByName( fs, fnode, "value" );
3301    if( data->is_classifier )
3302        node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3303
3304    node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3305    node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3306    node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3307    node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3308    node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3309    node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3310
3311    splits = cvGetFileNodeByName( fs, fnode, "splits" );
3312    if( splits )
3313    {
3314        CvSeqReader reader;
3315        CvDTreeSplit* last_split = 0;
3316
3317        if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3318            CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3319
3320        cvStartReadSeq( splits->data.seq, &reader );
3321        for( i = 0; i < reader.seq->total; i++ )
3322        {
3323            CvDTreeSplit* split;
3324            CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3325            if( !last_split )
3326                node->split = last_split = split;
3327            else
3328                last_split = last_split->next = split;
3329
3330            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3331        }
3332    }
3333
3334    __END__;
3335
3336    return node;
3337}
3338
3339
3340void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
3341{
3342    CV_FUNCNAME( "CvDTree::read_tree_nodes" );
3343
3344    __BEGIN__;
3345
3346    CvSeqReader reader;
3347    CvDTreeNode _root;
3348    CvDTreeNode* parent = &_root;
3349    int i;
3350    parent->left = parent->right = parent->parent = 0;
3351
3352    cvStartReadSeq( fnode->data.seq, &reader );
3353
3354    for( i = 0; i < reader.seq->total; i++ )
3355    {
3356        CvDTreeNode* node;
3357
3358        CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
3359        if( !parent->left )
3360            parent->left = node;
3361        else
3362            parent->right = node;
3363        if( node->split )
3364            parent = node;
3365        else
3366        {
3367            while( parent && parent->right )
3368                parent = parent->parent;
3369        }
3370
3371        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3372    }
3373
3374    root = _root.left;
3375
3376    __END__;
3377}
3378
3379
3380void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
3381{
3382    CvDTreeTrainData* _data = new CvDTreeTrainData();
3383    _data->read_params( fs, fnode );
3384
3385    read( fs, fnode, _data );
3386    get_var_importance();
3387}
3388
3389
3390// a special entry point for reading weak decision trees from the tree ensembles
3391void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
3392{
3393    CV_FUNCNAME( "CvDTree::read" );
3394
3395    __BEGIN__;
3396
3397    CvFileNode* tree_nodes;
3398
3399    clear();
3400    data = _data;
3401
3402    tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
3403    if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
3404        CV_ERROR( CV_StsParseError, "nodes tag is missing" );
3405
3406    pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
3407    read_tree_nodes( fs, tree_nodes );
3408
3409    __END__;
3410}
3411
3412/* End of file. */
3413