1/*M///////////////////////////////////////////////////////////////////////////////////////
2//
3//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4//
5//  By downloading, copying, installing or using the software you agree to this license.
6//  If you do not agree to this license, do not download, install,
7//  copy or use the software.
8//
9//
10//                        Intel License Agreement
11//
12// Copyright (C) 2000, Intel Corporation, all rights reserved.
13// Third party copyrights are property of their respective owners.
14//
15// Redistribution and use in source and binary forms, with or without modification,
16// are permitted provided that the following conditions are met:
17//
18//   * Redistribution's of source code must retain the above copyright notice,
19//     this list of conditions and the following disclaimer.
20//
21//   * Redistribution's in binary form must reproduce the above copyright notice,
22//     this list of conditions and the following disclaimer in the documentation
23//     and/or other materials provided with the distribution.
24//
25//   * The name of Intel Corporation may not be used to endorse or promote products
26//     derived from this software without specific prior written permission.
27//
28// This software is provided by the copyright holders and contributors "as is" and
29// any express or implied warranties, including, but not limited to, the implied
30// warranties of merchantability and fitness for a particular purpose are disclaimed.
31// In no event shall the Intel Corporation or contributors be liable for any direct,
32// indirect, incidental, special, exemplary, or consequential damages
33// (including, but not limited to, procurement of substitute goods or services;
34// loss of use, data, or profits; or business interruption) however caused
35// and on any theory of liability, whether in contract, strict liability,
36// or tort (including negligence or otherwise) arising in any way out of
37// the use of this software, even if advised of the possibility of such damage.
38//
39//M*/
40
41#include "_ml.h"
42
// Returns log(val/(1 - val)) after clamping val to [1e-5, 1 - 1e-5],
// which keeps the argument of the logarithm strictly positive and finite.
static inline double
log_ratio( double val )
{
    const double eps = 1e-5;
    const double upper = 1. - eps;

    if( val < eps )
        val = eps;
    else if( val > upper )
        val = upper;

    return log( val/(1. - val) );
}
52
53
54CvBoostParams::CvBoostParams()
55{
56    boost_type = CvBoost::REAL;
57    weak_count = 100;
58    weight_trim_rate = 0.95;
59    cv_folds = 0;
60    max_depth = 1;
61}
62
63
64CvBoostParams::CvBoostParams( int _boost_type, int _weak_count,
65                                        double _weight_trim_rate, int _max_depth,
66                                        bool _use_surrogates, const float* _priors )
67{
68    boost_type = _boost_type;
69    weak_count = _weak_count;
70    weight_trim_rate = _weight_trim_rate;
71    split_criteria = CvBoost::DEFAULT;
72    cv_folds = 0;
73    max_depth = _max_depth;
74    use_surrogates = _use_surrogates;
75    priors = _priors;
76}
77
78
79
80///////////////////////////////// CvBoostTree ///////////////////////////////////
81
// Creates a tree that is not attached to any boosting ensemble yet.
CvBoostTree::CvBoostTree()
{
    ensemble = 0;
}
86
87
// Destroys the tree; all cleanup is delegated to clear().
CvBoostTree::~CvBoostTree()
{
    clear();
}
92
93
// Resets the tree: clears the base CvDTree state and forgets the ensemble
// pointer (the ensemble object itself is not deleted here).
void
CvBoostTree::clear()
{
    CvDTree::clear();
    ensemble = 0;
}
100
101
102bool
103CvBoostTree::train( CvDTreeTrainData* _train_data,
104                    const CvMat* _subsample_idx, CvBoost* _ensemble )
105{
106    clear();
107    ensemble = _ensemble;
108    data = _train_data;
109    data->shared = true;
110
111    return do_train( _subsample_idx );
112}
113
114
// Stub for the train() overload that takes raw matrices: it is not supported
// for a boosted tree. Asserts in debug builds and reports failure otherwise.
bool
CvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*,
                    const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
{
    assert(0);
    return false;
}
122
123
// Stub for the train() overload without an ensemble argument: it is not
// supported for a boosted tree. Asserts in debug builds and reports failure
// otherwise.
bool
CvBoostTree::train( CvDTreeTrainData*, const CvMat* )
{
    assert(0);
    return false;
}
130
131
132void
133CvBoostTree::scale( double scale )
134{
135    CvDTreeNode* node = root;
136
137    // traverse the tree and scale all the node values
138    for(;;)
139    {
140        CvDTreeNode* parent;
141        for(;;)
142        {
143            node->value *= scale;
144            if( !node->left )
145                break;
146            node = node->left;
147        }
148
149        for( parent = node->parent; parent && parent->right == node;
150            node = parent, parent = parent->parent )
151            ;
152
153        if( !parent )
154            break;
155
156        node = parent->right;
157    }
158}
159
160
161void
162CvBoostTree::try_split_node( CvDTreeNode* node )
163{
164    CvDTree::try_split_node( node );
165
166    if( !node->left )
167    {
168        // if the node has not been split,
169        // store the responses for the corresponding training samples
170        double* weak_eval = ensemble->get_weak_response()->data.db;
171        int* labels = data->get_labels( node );
172        int i, count = node->sample_count;
173        double value = node->value;
174
175        for( i = 0; i < count; i++ )
176            weak_eval[labels[i]] = value;
177    }
178}
179
180
181double
182CvBoostTree::calc_node_dir( CvDTreeNode* node )
183{
184    char* dir = (char*)data->direction->data.ptr;
185    const double* weights = ensemble->get_subtree_weights()->data.db;
186    int i, n = node->sample_count, vi = node->split->var_idx;
187    double L, R;
188
189    assert( !node->split->inversed );
190
191    if( data->get_var_type(vi) >= 0 ) // split on categorical var
192    {
193        const int* cat_labels = data->get_cat_var_data( node, vi );
194        const int* subset = node->split->subset;
195        double sum = 0, sum_abs = 0;
196
197        for( i = 0; i < n; i++ )
198        {
199            int idx = cat_labels[i];
200            double w = weights[i];
201            int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
202            sum += d*w; sum_abs += (d & 1)*w;
203            dir[i] = (char)d;
204        }
205
206        R = (sum_abs + sum) * 0.5;
207        L = (sum_abs - sum) * 0.5;
208    }
209    else // split on ordered var
210    {
211        const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
212        int split_point = node->split->ord.split_point;
213        int n1 = node->get_num_valid(vi);
214
215        assert( 0 <= split_point && split_point < n1-1 );
216        L = R = 0;
217
218        for( i = 0; i <= split_point; i++ )
219        {
220            int idx = sorted[i].i;
221            double w = weights[idx];
222            dir[idx] = (char)-1;
223            L += w;
224        }
225
226        for( ; i < n1; i++ )
227        {
228            int idx = sorted[i].i;
229            double w = weights[idx];
230            dir[idx] = (char)1;
231            R += w;
232        }
233
234        for( ; i < n; i++ )
235            dir[sorted[i].i] = (char)0;
236    }
237
238    node->maxlr = MAX( L, R );
239    return node->split->quality/(L + R);
240}
241
242
// Searches for the best threshold split of the node on the ordered variable vi
// for a 2-class problem, using the weighted Gini or the weighted
// misclassification criterion.
// Returns a new ordered split, or 0 if no admissible threshold was found.
CvDTreeSplit*
CvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi )
{
    // minimal gap between neighbouring values for a threshold to be placed between them
    const float epsilon = FLT_EPSILON*2;
    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    const int* responses = data->get_class_labels(node);
    const double* weights = ensemble->get_subtree_weights()->data.db;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);
    // the two per-class weight totals are stored right after the n sample weights
    // (calc_node_value() writes them at indices count and count+1)
    const double* rcw0 = weights + n;
    // lcw/rcw - per-class weights currently on the left/right of the threshold
    double lcw[2] = {0,0}, rcw[2];
    int i, best_i = -1;
    double best_val = 0;
    int boost_type = ensemble->get_params().boost_type;
    int split_criteria = ensemble->get_params().split_criteria;

    // start with everything on the right, then exclude the samples stored
    // past the first n1 sorted entries (those without a valid value of vi)
    rcw[0] = rcw0[0]; rcw[1] = rcw0[1];
    for( i = n1; i < n; i++ )
    {
        int idx = sorted[i].i;
        double w = weights[idx];
        rcw[responses[idx]] -= w;
    }

    // any other criterion falls back to the default for the boosting type
    if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
        split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;

    if( split_criteria == CvBoost::GINI )
    {
        // L, R - total weight on each side;
        // lsum2, rsum2 - sums of squared per-class weights on each side
        double L = 0, R = rcw[0] + rcw[1];
        double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = sorted[i].i;
            double w = weights[idx], w2 = w*w;
            double lv, rv;
            // from here on idx is the class label of the sample
            idx = responses[idx];
            L += w; R -= w;
            lv = lcw[idx]; rv = rcw[idx];
            // incremental update of the squared sums:
            // (lv + w)^2 - lv^2 = 2*lv*w + w^2, (rv - w)^2 - rv^2 = -(2*rv*w - w^2)
            lsum2 += 2*lv*w + w2;
            rsum2 -= 2*rv*w - w2;
            lcw[idx] = lv + w; rcw[idx] = rv - w;

            // a threshold is only possible between distinct values
            if( sorted[i].val + epsilon < sorted[i+1].val )
            {
                // equals lsum2/L + rsum2/R; larger is better
                double val = (lsum2*R + rsum2*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }
    else
    {
        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = sorted[i].i;
            double w = weights[idx];
            // from here on idx is the class label of the sample
            idx = responses[idx];
            lcw[idx] += w;
            rcw[idx] -= w;

            if( sorted[i].val + epsilon < sorted[i+1].val )
            {
                // weight of the correctly classified samples for the two possible
                // assignments of classes to sides; the better one is taken
                double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];
                val = MAX(val, val2);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }

    // the threshold is the midpoint between the two neighbouring values
    return best_i >= 0 ? data->new_split_ord( vi,
        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
        0, (float)best_val ) : 0;
}
325
326
// "less than" predicate that compares two pointers by the values they point to
#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
// instantiates the file-local icvSortDblPtr(), which sorts an array of double*
// in increasing order of the pointed-to values (callers below pass 0 as the
// trailing int argument, which the predicate does not use)
static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
329
// Searches for the best subset split of the node on the categorical variable vi
// for a 2-class problem, using the weighted Gini or the weighted
// misclassification criterion. Categories are ordered by their class-1 weight
// and only prefixes of that ordering are considered as the left subset.
// Returns a new categorical split, or 0 if no admissible subset was found.
CvDTreeSplit*
CvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi )
{
    CvDTreeSplit* split;
    const int* cat_labels = data->get_cat_var_data(node, vi);
    const int* responses = data->get_class_labels(node);
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int mi = data->cat_count->data.i[ci];
    // lcw/rcw - per-class weights currently in the left/right subset
    double lcw[2]={0,0}, rcw[2]={0,0};
    // (mi+1) rows of 2 counters; the +2 offset makes row -1 addressable
    double* cjk = (double*)cvStackAlloc(2*(mi+1)*sizeof(cjk[0]))+2;
    const double* weights = ensemble->get_subtree_weights()->data.db;
    double** dbl_ptr = (double**)cvStackAlloc( mi*sizeof(dbl_ptr[0]) );
    int i, j, k, idx;
    double L = 0, R;
    double best_val = 0;
    int best_subset = -1, subset_i;
    int boost_type = ensemble->get_params().boost_type;
    int split_criteria = ensemble->get_params().split_criteria;

    // init array of counters:
    // c_{jk} - total weight of samples that have vi-th input variable = j and response = k.
    // row -1 collects the samples whose category index is -1
    for( j = -1; j < mi; j++ )
        cjk[j*2] = cjk[j*2+1] = 0;

    for( i = 0; i < n; i++ )
    {
        double w = weights[i];
        j = cat_labels[i];
        k = responses[i];
        cjk[j*2 + k] += w;
    }

    // initially all the categories 0..mi-1 are on the right;
    // dbl_ptr[j] points to the class-1 counter of the j-th category
    for( j = 0; j < mi; j++ )
    {
        rcw[0] += cjk[j*2];
        rcw[1] += cjk[j*2+1];
        dbl_ptr[j] = cjk + j*2 + 1;
    }

    R = rcw[0] + rcw[1];

    // any other criterion falls back to the default for the boosting type
    if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
        split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;

    // sort rows of c_jk by increasing c_j,1
    // (i.e. by the weight of samples in j-th category that belong to class 1)
    icvSortDblPtr( dbl_ptr, mi, 0 );

    // move the categories to the left one by one, in the sorted order
    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        // recover the category index from the pointer to its class-1 counter
        idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        const double* crow = cjk + idx*2;
        double w0 = crow[0], w1 = crow[1];
        double weight = w0 + w1;

        // categories with (almost) no weight do not change the split
        if( weight < FLT_EPSILON )
            continue;

        lcw[0] += w0; rcw[0] -= w0;
        lcw[1] += w1; rcw[1] -= w1;

        if( split_criteria == CvBoost::GINI )
        {
            double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];
            double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];

            L += weight;
            R -= weight;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                // equals lsum2/L + rsum2/R; larger is better
                double val = (lsum2*R + rsum2*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
        else
        {
            // weight of the correctly classified samples for the two possible
            // assignments of classes to sides; the better one is taken
            double val = lcw[0] + rcw[1];
            double val2 = lcw[1] + rcw[0];

            val = MAX(val, val2);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    if( best_subset < 0 )
        return 0;

    split = data->new_split_cat( vi, (float)best_val );

    // set the subset bit of every category up to the best position in the sorted order
    for( i = 0; i <= best_subset; i++ )
    {
        idx = (int)(dbl_ptr[i] - cjk) >> 1;
        split->subset[idx >> 5] |= 1 << (idx & 31);
    }

    return split;
}
437
438
439CvDTreeSplit*
440CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi )
441{
442    const float epsilon = FLT_EPSILON*2;
443    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
444    const float* responses = data->get_ord_responses(node);
445    const double* weights = ensemble->get_subtree_weights()->data.db;
446    int n = node->sample_count;
447    int n1 = node->get_num_valid(vi);
448    int i, best_i = -1;
449    double best_val = 0, lsum = 0, rsum = node->value*n;
450    double L = 0, R = weights[n];
451
452    // compensate for missing values
453    for( i = n1; i < n; i++ )
454    {
455        int idx = sorted[i].i;
456        double w = weights[idx];
457        rsum -= responses[idx]*w;
458        R -= w;
459    }
460
461    // find the optimal split
462    for( i = 0; i < n1 - 1; i++ )
463    {
464        int idx = sorted[i].i;
465        double w = weights[idx];
466        double t = responses[idx]*w;
467        L += w; R -= w;
468        lsum += t; rsum -= t;
469
470        if( sorted[i].val + epsilon < sorted[i+1].val )
471        {
472            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
473            if( best_val < val )
474            {
475                best_val = val;
476                best_i = i;
477            }
478        }
479    }
480
481    return best_i >= 0 ? data->new_split_ord( vi,
482        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
483        0, (float)best_val ) : 0;
484}
485
486
487CvDTreeSplit*
488CvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi )
489{
490    CvDTreeSplit* split;
491    const int* cat_labels = data->get_cat_var_data(node, vi);
492    const float* responses = data->get_ord_responses(node);
493    const double* weights = ensemble->get_subtree_weights()->data.db;
494    int ci = data->get_var_type(vi);
495    int n = node->sample_count;
496    int mi = data->cat_count->data.i[ci];
497    double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
498    double* counts = (double*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
499    double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) );
500    double L = 0, R = 0, best_val = 0, lsum = 0, rsum = 0;
501    int i, best_subset = -1, subset_i;
502
503    for( i = -1; i < mi; i++ )
504        sum[i] = counts[i] = 0;
505
506    // calculate sum response and weight of each category of the input var
507    for( i = 0; i < n; i++ )
508    {
509        int idx = cat_labels[i];
510        double w = weights[i];
511        double s = sum[idx] + responses[i]*w;
512        double nc = counts[idx] + w;
513        sum[idx] = s;
514        counts[idx] = nc;
515    }
516
517    // calculate average response in each category
518    for( i = 0; i < mi; i++ )
519    {
520        R += counts[i];
521        rsum += sum[i];
522        sum[i] /= counts[i];
523        sum_ptr[i] = sum + i;
524    }
525
526    icvSortDblPtr( sum_ptr, mi, 0 );
527
528    // revert back to unnormalized sums
529    // (there should be a very little loss in accuracy)
530    for( i = 0; i < mi; i++ )
531        sum[i] *= counts[i];
532
533    for( subset_i = 0; subset_i < mi-1; subset_i++ )
534    {
535        int idx = (int)(sum_ptr[subset_i] - sum);
536        double ni = counts[idx];
537
538        if( ni > FLT_EPSILON )
539        {
540            double s = sum[idx];
541            lsum += s; L += ni;
542            rsum -= s; R -= ni;
543
544            if( L > FLT_EPSILON && R > FLT_EPSILON )
545            {
546                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
547                if( best_val < val )
548                {
549                    best_val = val;
550                    best_subset = subset_i;
551                }
552            }
553        }
554    }
555
556    if( best_subset < 0 )
557        return 0;
558
559    split = data->new_split_cat( vi, (float)best_val );
560    for( i = 0; i <= best_subset; i++ )
561    {
562        int idx = (int)(sum_ptr[i] - sum);
563        split->subset[idx >> 5] |= 1 << (idx & 31);
564    }
565
566    return split;
567}
568
569
570CvDTreeSplit*
571CvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
572{
573    const float epsilon = FLT_EPSILON*2;
574    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
575    const double* weights = ensemble->get_subtree_weights()->data.db;
576    const char* dir = (char*)data->direction->data.ptr;
577    int n1 = node->get_num_valid(vi);
578    // LL - number of samples that both the primary and the surrogate splits send to the left
579    // LR - ... primary split sends to the left and the surrogate split sends to the right
580    // RL - ... primary split sends to the right and the surrogate split sends to the left
581    // RR - ... both send to the right
582    int i, best_i = -1, best_inversed = 0;
583    double best_val;
584    double LL = 0, RL = 0, LR, RR;
585    double worst_val = node->maxlr;
586    double sum = 0, sum_abs = 0;
587    best_val = worst_val;
588
589    for( i = 0; i < n1; i++ )
590    {
591        int idx = sorted[i].i;
592        double w = weights[idx];
593        int d = dir[idx];
594        sum += d*w; sum_abs += (d & 1)*w;
595    }
596
597    // sum_abs = R + L; sum = R - L
598    RR = (sum_abs + sum)*0.5;
599    LR = (sum_abs - sum)*0.5;
600
601    // initially all the samples are sent to the right by the surrogate split,
602    // LR of them are sent to the left by primary split, and RR - to the right.
603    // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
604    for( i = 0; i < n1 - 1; i++ )
605    {
606        int idx = sorted[i].i;
607        double w = weights[idx];
608        int d = dir[idx];
609
610        if( d < 0 )
611        {
612            LL += w; LR -= w;
613            if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
614            {
615                best_val = LL + RR;
616                best_i = i; best_inversed = 0;
617            }
618        }
619        else if( d > 0 )
620        {
621            RL += w; RR -= w;
622            if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
623            {
624                best_val = RL + LR;
625                best_i = i; best_inversed = 1;
626            }
627        }
628    }
629
630    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
631        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
632        best_inversed, (float)best_val ) : 0;
633}
634
635
636CvDTreeSplit*
637CvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
638{
639    const int* cat_labels = data->get_cat_var_data(node, vi);
640    const char* dir = (char*)data->direction->data.ptr;
641    const double* weights = ensemble->get_subtree_weights()->data.db;
642    int n = node->sample_count;
643    // LL - number of samples that both the primary and the surrogate splits send to the left
644    // LR - ... primary split sends to the left and the surrogate split sends to the right
645    // RL - ... primary split sends to the right and the surrogate split sends to the left
646    // RR - ... both send to the right
647    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
648    int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
649    double best_val = 0;
650    double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
651    double* rc = lc + mi + 1;
652
653    for( i = -1; i < mi; i++ )
654        lc[i] = rc[i] = 0;
655
656    // 1. for each category calculate the weight of samples
657    // sent to the left (lc) and to the right (rc) by the primary split
658    for( i = 0; i < n; i++ )
659    {
660        int idx = cat_labels[i];
661        double w = weights[i];
662        int d = dir[i];
663        double sum = lc[idx] + d*w;
664        double sum_abs = rc[idx] + (d & 1)*w;
665        lc[idx] = sum; rc[idx] = sum_abs;
666    }
667
668    for( i = 0; i < mi; i++ )
669    {
670        double sum = lc[i];
671        double sum_abs = rc[i];
672        lc[i] = (sum_abs - sum) * 0.5;
673        rc[i] = (sum_abs + sum) * 0.5;
674    }
675
676    // 2. now form the split.
677    // in each category send all the samples to the same direction as majority
678    for( i = 0; i < mi; i++ )
679    {
680        double lval = lc[i], rval = rc[i];
681        if( lval > rval )
682        {
683            split->subset[i >> 5] |= 1 << (i & 31);
684            best_val += lval;
685        }
686        else
687            best_val += rval;
688    }
689
690    split->quality = (float)best_val;
691    if( split->quality <= node->maxlr )
692        cvSetRemoveByPtr( data->split_heap, split ), split = 0;
693
694    return split;
695}
696
697
698void
699CvBoostTree::calc_node_value( CvDTreeNode* node )
700{
701    int i, count = node->sample_count;
702    const double* weights = ensemble->get_weights()->data.db;
703    const int* labels = data->get_labels(node);
704    double* subtree_weights = ensemble->get_subtree_weights()->data.db;
705    double rcw[2] = {0,0};
706    int boost_type = ensemble->get_params().boost_type;
707    //const double* priors = data->priors->data.db;
708
709    if( data->is_classifier )
710    {
711        const int* responses = data->get_class_labels(node);
712
713        for( i = 0; i < count; i++ )
714        {
715            int idx = labels[i];
716            double w = weights[idx]/*priors[responses[i]]*/;
717            rcw[responses[i]] += w;
718            subtree_weights[i] = w;
719        }
720
721        node->class_idx = rcw[1] > rcw[0];
722
723        if( boost_type == CvBoost::DISCRETE )
724        {
725            // ignore cat_map for responses, and use {-1,1},
726            // as the whole ensemble response is computes as sign(sum_i(weak_response_i)
727            node->value = node->class_idx*2 - 1;
728        }
729        else
730        {
731            double p = rcw[1]/(rcw[0] + rcw[1]);
732            assert( boost_type == CvBoost::REAL );
733
734            // store log-ratio of the probability
735            node->value = 0.5*log_ratio(p);
736        }
737    }
738    else
739    {
740        // in case of regression tree:
741        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
742        //    n is the number of samples in the node.
743        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
744        double sum = 0, sum2 = 0, iw;
745        const float* values = data->get_ord_responses(node);
746
747        for( i = 0; i < count; i++ )
748        {
749            int idx = labels[i];
750            double w = weights[idx]/*priors[values[i] > 0]*/;
751            double t = values[i];
752            rcw[0] += w;
753            subtree_weights[i] = w;
754            sum += t*w;
755            sum2 += t*t*w;
756        }
757
758        iw = 1./rcw[0];
759        node->value = sum*iw;
760        node->node_risk = sum2 - (sum*iw)*sum;
761
762        // renormalize the risk, as in try_split_node the unweighted formula
763        // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)
764        node->node_risk *= count*iw*count*iw;
765    }
766
767    // store summary weights
768    subtree_weights[count] = rcw[0];
769    subtree_weights[count+1] = rcw[1];
770}
771
772
// Reads the tree from the file node through the base class and attaches it
// to the given ensemble.
void CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )
{
    CvDTree::read( fs, fnode, _data );
    ensemble = _ensemble;
}
778
779
// Stub for the read() overload without training data: it is not supported
// for a boosted tree and asserts in debug builds.
void CvBoostTree::read( CvFileStorage*, CvFileNode* )
{
    assert(0);
}
784
// Reads the tree through the base class; the ensemble pointer is left untouched.
void CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node,
                        CvDTreeTrainData* _data )
{
    CvDTree::read( _fs, _node, _data );
}
790
791
792/////////////////////////////////// CvBoost /////////////////////////////////////
793
794CvBoost::CvBoost()
795{
796    data = 0;
797    weak = 0;
798    default_model_name = "my_boost_tree";
799    orig_response = sum_response = weak_eval = subsample_mask =
800        weights = subtree_weights = 0;
801
802    clear();
803}
804
805
806void CvBoost::prune( CvSlice slice )
807{
808    if( weak )
809    {
810        CvSeqReader reader;
811        int i, count = cvSliceLength( slice, weak );
812
813        cvStartReadSeq( weak, &reader );
814        cvSetSeqReaderPos( &reader, slice.start_index );
815
816        for( i = 0; i < count; i++ )
817        {
818            CvBoostTree* w;
819            CV_READ_SEQ_ELEM( w, reader );
820            delete w;
821        }
822
823        cvSeqRemoveSlice( weak, slice );
824    }
825}
826
827
828void CvBoost::clear()
829{
830    if( weak )
831    {
832        prune( CV_WHOLE_SEQ );
833        cvReleaseMemStorage( &weak->storage );
834    }
835    if( data )
836        delete data;
837    weak = 0;
838    data = 0;
839    cvReleaseMat( &orig_response );
840    cvReleaseMat( &sum_response );
841    cvReleaseMat( &weak_eval );
842    cvReleaseMat( &subsample_mask );
843    cvReleaseMat( &weights );
844    have_subsample = false;
845}
846
847
848CvBoost::~CvBoost()
849{
850    clear();
851}
852
853
854CvBoost::CvBoost( const CvMat* _train_data, int _tflag,
855                  const CvMat* _responses, const CvMat* _var_idx,
856                  const CvMat* _sample_idx, const CvMat* _var_type,
857                  const CvMat* _missing_mask, CvBoostParams _params )
858{
859    weak = 0;
860    data = 0;
861    default_model_name = "my_boost_tree";
862    orig_response = sum_response = weak_eval = subsample_mask = weights = 0;
863
864    train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
865           _var_type, _missing_mask, _params );
866}
867
868
869bool
870CvBoost::set_params( const CvBoostParams& _params )
871{
872    bool ok = false;
873
874    CV_FUNCNAME( "CvBoost::set_params" );
875
876    __BEGIN__;
877
878    params = _params;
879    if( params.boost_type != DISCRETE && params.boost_type != REAL &&
880        params.boost_type != LOGIT && params.boost_type != GENTLE )
881        CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );
882
883    params.weak_count = MAX( params.weak_count, 1 );
884    params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );
885    params.weight_trim_rate = MIN( params.weight_trim_rate, 1. );
886    if( params.weight_trim_rate < FLT_EPSILON )
887        params.weight_trim_rate = 1.f;
888
889    if( params.boost_type == DISCRETE &&
890        params.split_criteria != GINI && params.split_criteria != MISCLASS )
891        params.split_criteria = MISCLASS;
892    if( params.boost_type == REAL &&
893        params.split_criteria != GINI && params.split_criteria != MISCLASS )
894        params.split_criteria = GINI;
895    if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&
896        params.split_criteria != SQERR )
897        params.split_criteria = SQERR;
898
899    ok = true;
900
901    __END__;
902
903    return ok;
904}
905
906
907bool
908CvBoost::train( const CvMat* _train_data, int _tflag,
909              const CvMat* _responses, const CvMat* _var_idx,
910              const CvMat* _sample_idx, const CvMat* _var_type,
911              const CvMat* _missing_mask,
912              CvBoostParams _params, bool _update )
913{
914    bool ok = false;
915    CvMemStorage* storage = 0;
916
917    CV_FUNCNAME( "CvBoost::train" );
918
919    __BEGIN__;
920
921    int i;
922
923    set_params( _params );
924
925    if( !_update || !data )
926    {
927        clear();
928        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
929            _sample_idx, _var_type, _missing_mask, _params, true, true );
930
931        if( data->get_num_classes() != 2 )
932            CV_ERROR( CV_StsNotImplemented,
933            "Boosted trees can only be used for 2-class classification." );
934        CV_CALL( storage = cvCreateMemStorage() );
935        weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
936        storage = 0;
937    }
938    else
939    {
940        data->set_data( _train_data, _tflag, _responses, _var_idx,
941            _sample_idx, _var_type, _missing_mask, _params, true, true, true );
942    }
943
944    update_weights( 0 );
945
946    for( i = 0; i < params.weak_count; i++ )
947    {
948        CvBoostTree* tree = new CvBoostTree;
949        if( !tree->train( data, subsample_mask, this ) )
950        {
951            delete tree;
952            continue;
953        }
954        //cvCheckArr( get_weak_response());
955        cvSeqPush( weak, &tree );
956        update_weights( tree );
957        trim_weights();
958    }
959
960    data->is_classifier = true;
961    ok = true;
962
963    __END__;
964
965    return ok;
966}
967
968
969void
970CvBoost::update_weights( CvBoostTree* tree )
971{
972    CV_FUNCNAME( "CvBoost::update_weights" );
973
974    __BEGIN__;
975
976    int i, count = data->sample_count;
977    double sumw = 0.;
978
979    if( !tree ) // before training the first tree, initialize weights and other parameters
980    {
981        const int* class_labels = data->get_class_labels(data->data_root);
982        // in case of logitboost and gentle adaboost each weak tree is a regression tree,
983        // so we need to convert class labels to floating-point values
984        float* responses = data->get_ord_responses(data->data_root);
985        int* labels = data->get_labels(data->data_root);
986        double w0 = 1./count;
987        double p[2] = { 1, 1 };
988
989        cvReleaseMat( &orig_response );
990        cvReleaseMat( &sum_response );
991        cvReleaseMat( &weak_eval );
992        cvReleaseMat( &subsample_mask );
993        cvReleaseMat( &weights );
994
995        CV_CALL( orig_response = cvCreateMat( 1, count, CV_32S ));
996        CV_CALL( weak_eval = cvCreateMat( 1, count, CV_64F ));
997        CV_CALL( subsample_mask = cvCreateMat( 1, count, CV_8U ));
998        CV_CALL( weights = cvCreateMat( 1, count, CV_64F ));
999        CV_CALL( subtree_weights = cvCreateMat( 1, count + 2, CV_64F ));
1000
1001        if( data->have_priors )
1002        {
1003            // compute weight scale for each class from their prior probabilities
1004            int c1 = 0;
1005            for( i = 0; i < count; i++ )
1006                c1 += class_labels[i];
1007            p[0] = data->priors->data.db[0]*(c1 < count ? 1./(count - c1) : 0.);
1008            p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.);
1009            p[0] /= p[0] + p[1];
1010            p[1] = 1. - p[0];
1011        }
1012
1013        for( i = 0; i < count; i++ )
1014        {
1015            // save original categorical responses {0,1}, convert them to {-1,1}
1016            orig_response->data.i[i] = class_labels[i]*2 - 1;
1017            // make all the samples active at start.
1018            // later, in trim_weights() deactivate/reactive again some, if need
1019            subsample_mask->data.ptr[i] = (uchar)1;
1020            // make all the initial weights the same.
1021            weights->data.db[i] = w0*p[class_labels[i]];
1022            // set the labels to find (from within weak tree learning proc)
1023            // the particular sample weight, and where to store the response.
1024            labels[i] = i;
1025        }
1026
1027        if( params.boost_type == LOGIT )
1028        {
1029            CV_CALL( sum_response = cvCreateMat( 1, count, CV_64F ));
1030
1031            for( i = 0; i < count; i++ )
1032            {
1033                sum_response->data.db[i] = 0;
1034                responses[i] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
1035            }
1036
1037            // in case of logitboost each weak tree is a regression tree.
1038            // the target function values are recalculated for each of the trees
1039            data->is_classifier = false;
1040        }
1041        else if( params.boost_type == GENTLE )
1042        {
1043            for( i = 0; i < count; i++ )
1044                responses[i] = (float)orig_response->data.i[i];
1045
1046            data->is_classifier = false;
1047        }
1048    }
1049    else
1050    {
1051        // at this moment, for all the samples that participated in the training of the most
1052        // recent weak classifier we know the responses. For other samples we need to compute them
1053        if( have_subsample )
1054        {
1055            float* values = (float*)(data->buf->data.ptr + data->buf->step);
1056            uchar* missing = data->buf->data.ptr + data->buf->step*2;
1057            CvMat _sample, _mask;
1058
1059            // invert the subsample mask
1060            cvXorS( subsample_mask, cvScalar(1.), subsample_mask );
1061            data->get_vectors( subsample_mask, values, missing, 0 );
1062            //data->get_vectors( 0, values, missing, 0 );
1063
1064            _sample = cvMat( 1, data->var_count, CV_32F );
1065            _mask = cvMat( 1, data->var_count, CV_8U );
1066
1067            // run tree through all the non-processed samples
1068            for( i = 0; i < count; i++ )
1069                if( subsample_mask->data.ptr[i] )
1070                {
1071                    _sample.data.fl = values;
1072                    _mask.data.ptr = missing;
1073                    values += _sample.cols;
1074                    missing += _mask.cols;
1075                    weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;
1076                }
1077        }
1078
1079        // now update weights and other parameters for each type of boosting
1080        if( params.boost_type == DISCRETE )
1081        {
1082            // Discrete AdaBoost:
1083            //   weak_eval[i] (=f(x_i)) is in {-1,1}
1084            //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
1085            //   C = log((1-err)/err)
1086            //   w_i *= exp(C*(f(x_i) != y_i))
1087
1088            double C, err = 0.;
1089            double scale[] = { 1., 0. };
1090
1091            for( i = 0; i < count; i++ )
1092            {
1093                double w = weights->data.db[i];
1094                sumw += w;
1095                err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
1096            }
1097
1098            if( sumw != 0 )
1099                err /= sumw;
1100            C = err = -log_ratio( err );
1101            scale[1] = exp(err);
1102
1103            sumw = 0;
1104            for( i = 0; i < count; i++ )
1105            {
1106                double w = weights->data.db[i]*
1107                    scale[weak_eval->data.db[i] != orig_response->data.i[i]];
1108                sumw += w;
1109                weights->data.db[i] = w;
1110            }
1111
1112            tree->scale( C );
1113        }
1114        else if( params.boost_type == REAL )
1115        {
1116            // Real AdaBoost:
1117            //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
1118            //   w_i *= exp(-y_i*f(x_i))
1119
1120            for( i = 0; i < count; i++ )
1121                weak_eval->data.db[i] *= -orig_response->data.i[i];
1122
1123            cvExp( weak_eval, weak_eval );
1124
1125            for( i = 0; i < count; i++ )
1126            {
1127                double w = weights->data.db[i]*weak_eval->data.db[i];
1128                sumw += w;
1129                weights->data.db[i] = w;
1130            }
1131        }
1132        else if( params.boost_type == LOGIT )
1133        {
1134            // LogitBoost:
1135            //   weak_eval[i] = f(x_i) in [-z_max,z_max]
1136            //   sum_response = F(x_i).
1137            //   F(x_i) += 0.5*f(x_i)
1138            //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))
1139            //   reuse weak_eval: weak_eval[i] <- p(x_i)
1140            //   w_i = p(x_i)*1(1 - p(x_i))
1141            //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
1142            //   store z_i to the data->data_root as the new target responses
1143
1144            const double lb_weight_thresh = FLT_EPSILON;
1145            const double lb_z_max = 10.;
1146            float* responses = data->get_ord_responses(data->data_root);
1147
1148            /*if( weak->total == 7 )
1149                putchar('*');*/
1150
1151            for( i = 0; i < count; i++ )
1152            {
1153                double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
1154                sum_response->data.db[i] = s;
1155                weak_eval->data.db[i] = -2*s;
1156            }
1157
1158            cvExp( weak_eval, weak_eval );
1159
1160            for( i = 0; i < count; i++ )
1161            {
1162                double p = 1./(1. + weak_eval->data.db[i]);
1163                double w = p*(1 - p), z;
1164                w = MAX( w, lb_weight_thresh );
1165                weights->data.db[i] = w;
1166                sumw += w;
1167                if( orig_response->data.i[i] > 0 )
1168                {
1169                    z = 1./p;
1170                    responses[i] = (float)MIN(z, lb_z_max);
1171                }
1172                else
1173                {
1174                    z = 1./(1-p);
1175                    responses[i] = (float)-MIN(z, lb_z_max);
1176                }
1177            }
1178        }
1179        else
1180        {
1181            // Gentle AdaBoost:
1182            //   weak_eval[i] = f(x_i) in [-1,1]
1183            //   w_i *= exp(-y_i*f(x_i))
1184            assert( params.boost_type == GENTLE );
1185
1186            for( i = 0; i < count; i++ )
1187                weak_eval->data.db[i] *= -orig_response->data.i[i];
1188
1189            cvExp( weak_eval, weak_eval );
1190
1191            for( i = 0; i < count; i++ )
1192            {
1193                double w = weights->data.db[i] * weak_eval->data.db[i];
1194                weights->data.db[i] = w;
1195                sumw += w;
1196            }
1197        }
1198    }
1199
1200    // renormalize weights
1201    if( sumw > FLT_EPSILON )
1202    {
1203        sumw = 1./sumw;
1204        for( i = 0; i < count; ++i )
1205            weights->data.db[i] *= sumw;
1206    }
1207
1208    __END__;
1209}
1210
1211
1212static CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int )
1213
1214
1215void
1216CvBoost::trim_weights()
1217{
1218    CV_FUNCNAME( "CvBoost::trim_weights" );
1219
1220    __BEGIN__;
1221
1222    int i, count = data->sample_count, nz_count = 0;
1223    double sum, threshold;
1224
1225    if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )
1226        EXIT;
1227
1228    // use weak_eval as temporary buffer for sorted weights
1229    cvCopy( weights, weak_eval );
1230
1231    icvSort_64f( weak_eval->data.db, count, 0 );
1232
1233    // as weight trimming occurs immediately after updating the weights,
1234    // where they are renormalized, we assume that the weight sum = 1.
1235    sum = 1. - params.weight_trim_rate;
1236
1237    for( i = 0; i < count; i++ )
1238    {
1239        double w = weak_eval->data.db[i];
1240        if( sum > w )
1241            break;
1242        sum -= w;
1243    }
1244
1245    threshold = i < count ? weak_eval->data.db[i] : DBL_MAX;
1246
1247    for( i = 0; i < count; i++ )
1248    {
1249        double w = weights->data.db[i];
1250        int f = w > threshold;
1251        subsample_mask->data.ptr[i] = (uchar)f;
1252        nz_count += f;
1253    }
1254
1255    have_subsample = nz_count < count;
1256
1257    __END__;
1258}
1259
1260
1261float
1262CvBoost::predict( const CvMat* _sample, const CvMat* _missing,
1263                  CvMat* weak_responses, CvSlice slice,
1264                  bool raw_mode ) const
1265{
1266    float* buf = 0;
1267    bool allocated = false;
1268    float value = -FLT_MAX;
1269
1270    CV_FUNCNAME( "CvBoost::predict" );
1271
1272    __BEGIN__;
1273
1274    int i, weak_count, var_count;
1275    CvMat sample, missing;
1276    CvSeqReader reader;
1277    double sum = 0;
1278    int cls_idx;
1279    int wstep = 0;
1280    const int* vtype;
1281    const int* cmap;
1282    const int* cofs;
1283
1284    if( !weak )
1285        CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );
1286
1287    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
1288        _sample->cols != 1 && _sample->rows != 1 ||
1289        _sample->cols + _sample->rows - 1 != data->var_all && !raw_mode ||
1290        _sample->cols + _sample->rows - 1 != data->var_count && raw_mode )
1291            CV_ERROR( CV_StsBadArg,
1292        "the input sample must be 1d floating-point vector with the same "
1293        "number of elements as the total number of variables used for training" );
1294
1295    if( _missing )
1296    {
1297        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
1298            !CV_ARE_SIZES_EQ(_missing, _sample) )
1299            CV_ERROR( CV_StsBadArg,
1300            "the missing data mask must be 8-bit vector of the same size as input sample" );
1301    }
1302
1303    weak_count = cvSliceLength( slice, weak );
1304    if( weak_count >= weak->total )
1305    {
1306        weak_count = weak->total;
1307        slice.start_index = 0;
1308    }
1309
1310    if( weak_responses )
1311    {
1312        if( !CV_IS_MAT(weak_responses) ||
1313            CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||
1314            weak_responses->cols != 1 && weak_responses->rows != 1 ||
1315            weak_responses->cols + weak_responses->rows - 1 != weak_count )
1316            CV_ERROR( CV_StsBadArg,
1317            "The output matrix of weak classifier responses must be valid "
1318            "floating-point vector of the same number of components as the length of input slice" );
1319        wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);
1320    }
1321
1322    var_count = data->var_count;
1323    vtype = data->var_type->data.i;
1324    cmap = data->cat_map->data.i;
1325    cofs = data->cat_ofs->data.i;
1326
1327    // if need, preprocess the input vector
1328    if( !raw_mode && (data->cat_var_count > 0 || data->var_idx) )
1329    {
1330        int bufsize;
1331        int step, mstep = 0;
1332        const float* src_sample;
1333        const uchar* src_mask = 0;
1334        float* dst_sample;
1335        uchar* dst_mask;
1336        const int* vidx = data->var_idx && !raw_mode ? data->var_idx->data.i : 0;
1337        bool have_mask = _missing != 0;
1338
1339        bufsize = var_count*(sizeof(float) + sizeof(uchar));
1340        if( bufsize <= CV_MAX_LOCAL_SIZE )
1341            buf = (float*)cvStackAlloc( bufsize );
1342        else
1343        {
1344            CV_CALL( buf = (float*)cvAlloc( bufsize ));
1345            allocated = true;
1346        }
1347        dst_sample = buf;
1348        dst_mask = (uchar*)(buf + var_count);
1349
1350        src_sample = _sample->data.fl;
1351        step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);
1352
1353        if( _missing )
1354        {
1355            src_mask = _missing->data.ptr;
1356            mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step;
1357        }
1358
1359        for( i = 0; i < var_count; i++ )
1360        {
1361            int idx = vidx ? vidx[i] : i;
1362            float val = src_sample[idx*step];
1363            int ci = vtype[i];
1364            uchar m = src_mask ? src_mask[i] : (uchar)0;
1365
1366            if( ci >= 0 )
1367            {
1368                int a = cofs[ci], b = cofs[ci+1], c = a;
1369                int ival = cvRound(val);
1370                if( ival != val )
1371                    CV_ERROR( CV_StsBadArg,
1372                    "one of input categorical variable is not an integer" );
1373
1374                while( a < b )
1375                {
1376                    c = (a + b) >> 1;
1377                    if( ival < cmap[c] )
1378                        b = c;
1379                    else if( ival > cmap[c] )
1380                        a = c+1;
1381                    else
1382                        break;
1383                }
1384
1385                if( c < 0 || ival != cmap[c] )
1386                {
1387                    m = 1;
1388                    have_mask = true;
1389                }
1390                else
1391                {
1392                    val = (float)(c - cofs[ci]);
1393                }
1394            }
1395
1396            dst_sample[i] = val;
1397            dst_mask[i] = m;
1398        }
1399
1400        sample = cvMat( 1, var_count, CV_32F, dst_sample );
1401        _sample = &sample;
1402
1403        if( have_mask )
1404        {
1405            missing = cvMat( 1, var_count, CV_8UC1, dst_mask );
1406            _missing = &missing;
1407        }
1408    }
1409
1410    cvStartReadSeq( weak, &reader );
1411    cvSetSeqReaderPos( &reader, slice.start_index );
1412
1413    for( i = 0; i < weak_count; i++ )
1414    {
1415        CvBoostTree* wtree;
1416        double val;
1417
1418        CV_READ_SEQ_ELEM( wtree, reader );
1419
1420        val = wtree->predict( _sample, _missing, true )->value;
1421        if( weak_responses )
1422            weak_responses->data.fl[i*wstep] = (float)val;
1423
1424        sum += val;
1425    }
1426
1427    cls_idx = sum >= 0;
1428    if( raw_mode )
1429        value = (float)cls_idx;
1430    else
1431        value = (float)cmap[cofs[vtype[var_count]] + cls_idx];
1432
1433    __END__;
1434
1435    if( allocated )
1436        cvFree( &buf );
1437
1438    return value;
1439}
1440
1441
1442
1443void CvBoost::write_params( CvFileStorage* fs )
1444{
1445    CV_FUNCNAME( "CvBoost::write_params" );
1446
1447    __BEGIN__;
1448
1449    const char* boost_type_str =
1450        params.boost_type == DISCRETE ? "DiscreteAdaboost" :
1451        params.boost_type == REAL ? "RealAdaboost" :
1452        params.boost_type == LOGIT ? "LogitBoost" :
1453        params.boost_type == GENTLE ? "GentleAdaboost" : 0;
1454
1455    const char* split_crit_str =
1456        params.split_criteria == DEFAULT ? "Default" :
1457        params.split_criteria == GINI ? "Gini" :
1458        params.boost_type == MISCLASS ? "Misclassification" :
1459        params.boost_type == SQERR ? "SquaredErr" : 0;
1460
1461    if( boost_type_str )
1462        cvWriteString( fs, "boosting_type", boost_type_str );
1463    else
1464        cvWriteInt( fs, "boosting_type", params.boost_type );
1465
1466    if( split_crit_str )
1467        cvWriteString( fs, "splitting_criteria", split_crit_str );
1468    else
1469        cvWriteInt( fs, "splitting_criteria", params.split_criteria );
1470
1471    cvWriteInt( fs, "ntrees", params.weak_count );
1472    cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );
1473
1474    data->write_params( fs );
1475
1476    __END__;
1477}
1478
1479
1480void CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )
1481{
1482    CV_FUNCNAME( "CvBoost::read_params" );
1483
1484    __BEGIN__;
1485
1486    CvFileNode* temp;
1487
1488    if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
1489        return;
1490
1491    data = new CvDTreeTrainData();
1492    CV_CALL( data->read_params(fs, fnode));
1493    data->shared = true;
1494
1495    params.max_depth = data->params.max_depth;
1496    params.min_sample_count = data->params.min_sample_count;
1497    params.max_categories = data->params.max_categories;
1498    params.priors = data->params.priors;
1499    params.regression_accuracy = data->params.regression_accuracy;
1500    params.use_surrogates = data->params.use_surrogates;
1501
1502    temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );
1503    if( !temp )
1504        return;
1505
1506    if( temp && CV_NODE_IS_STRING(temp->tag) )
1507    {
1508        const char* boost_type_str = cvReadString( temp, "" );
1509        params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :
1510                            strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :
1511                            strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :
1512                            strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1;
1513    }
1514    else
1515        params.boost_type = cvReadInt( temp, -1 );
1516
1517    if( params.boost_type < DISCRETE || params.boost_type > GENTLE )
1518        CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1519
1520    temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );
1521    if( temp && CV_NODE_IS_STRING(temp->tag) )
1522    {
1523        const char* split_crit_str = cvReadString( temp, "" );
1524        params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :
1525                                strcmp( split_crit_str, "Gini" ) == 0 ? GINI :
1526                                strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :
1527                                strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;
1528    }
1529    else
1530        params.split_criteria = cvReadInt( temp, -1 );
1531
1532    if( params.split_criteria < DEFAULT || params.boost_type > SQERR )
1533        CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1534
1535    params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );
1536    params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );
1537
1538    __END__;
1539}
1540
1541
1542
1543void
1544CvBoost::read( CvFileStorage* fs, CvFileNode* node )
1545{
1546    CV_FUNCNAME( "CvRTrees::read" );
1547
1548    __BEGIN__;
1549
1550    CvSeqReader reader;
1551    CvFileNode* trees_fnode;
1552    CvMemStorage* storage;
1553    int i, ntrees;
1554
1555    clear();
1556    read_params( fs, node );
1557
1558    if( !data )
1559        EXIT;
1560
1561    trees_fnode = cvGetFileNodeByName( fs, node, "trees" );
1562    if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
1563        CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
1564
1565    cvStartReadSeq( trees_fnode->data.seq, &reader );
1566    ntrees = trees_fnode->data.seq->total;
1567
1568    if( ntrees != params.weak_count )
1569        CV_ERROR( CV_StsUnmatchedSizes,
1570        "The number of trees stored does not match <ntrees> tag value" );
1571
1572    CV_CALL( storage = cvCreateMemStorage() );
1573    weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
1574
1575    for( i = 0; i < ntrees; i++ )
1576    {
1577        CvBoostTree* tree = new CvBoostTree();
1578        CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data ));
1579        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1580        cvSeqPush( weak, &tree );
1581    }
1582
1583    __END__;
1584}
1585
1586
1587void
1588CvBoost::write( CvFileStorage* fs, const char* name )
1589{
1590    CV_FUNCNAME( "CvBoost::write" );
1591
1592    __BEGIN__;
1593
1594    CvSeqReader reader;
1595    int i;
1596
1597    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );
1598
1599    if( !weak )
1600        CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );
1601
1602    write_params( fs );
1603    cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
1604
1605    cvStartReadSeq( weak, &reader );
1606
1607    for( i = 0; i < weak->total; i++ )
1608    {
1609        CvBoostTree* tree;
1610        CV_READ_SEQ_ELEM( tree, reader );
1611        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
1612        tree->write( fs );
1613        cvEndWriteStruct( fs );
1614    }
1615
1616    cvEndWriteStruct( fs );
1617    cvEndWriteStruct( fs );
1618
1619    __END__;
1620}
1621
1622
1623CvMat*
1624CvBoost::get_weights()
1625{
1626    return weights;
1627}
1628
1629
1630CvMat*
1631CvBoost::get_subtree_weights()
1632{
1633    return subtree_weights;
1634}
1635
1636
1637CvMat*
1638CvBoost::get_weak_response()
1639{
1640    return weak_eval;
1641}
1642
1643
1644const CvBoostParams&
1645CvBoost::get_params() const
1646{
1647    return params;
1648}
1649
1650CvSeq* CvBoost::get_weak_predictors()
1651{
1652    return weak;
1653}
1654
1655/* End of file. */
1656