/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h"
426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
// Returns log( val/(1 - val) ), i.e. the log-odds of val.
// The argument is first clamped to [eps, 1 - eps] with eps = 1e-5, so the
// result stays finite for inputs at or outside the ends of [0, 1].
static inline double
log_ratio( double val )
{
    const double eps = 1e-5;
    const double upper = 1. - eps;

    // Clamp with plain comparisons; a NaN input fails both tests and is
    // passed through to log() unchanged.
    if( val < eps )
        val = eps;
    if( val > upper )
        val = upper;

    return log( val/(1. - val) );
}
526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostParams::CvBoostParams()
556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    boost_type = CvBoost::REAL;
576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weak_count = 100;
586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weight_trim_rate = 0.95;
596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cv_folds = 0;
606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_depth = 1;
616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostParams::CvBoostParams( int _boost_type, int _weak_count,
656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                        double _weight_trim_rate, int _max_depth,
666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                        bool _use_surrogates, const float* _priors )
676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    boost_type = _boost_type;
696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weak_count = _weak_count;
706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weight_trim_rate = _weight_trim_rate;
716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split_criteria = CvBoost::DEFAULT;
726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cv_folds = 0;
736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_depth = _max_depth;
746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    use_surrogates = _use_surrogates;
756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    priors = _priors;
766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
///////////////////////////////// CvBoostTree ///////////////////////////////////
816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::CvBoostTree()
836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ensemble = 0;
856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::~CvBoostTree()
896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid
956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::clear()
966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTree::clear();
986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ensemble = 0;
996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool
1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::train( CvDTreeTrainData* _train_data,
1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const CvMat* _subsample_idx, CvBoost* _ensemble )
1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ensemble = _ensemble;
1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = _train_data;
1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->shared = true;
1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return do_train( _subsample_idx );
1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool
1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*,
1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert(0);
1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return false;
1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool
1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::train( CvDTreeTrainData*, const CvMat* )
1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert(0);
1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return false;
1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid
1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::scale( double scale )
1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* node = root;
1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // traverse the tree and scale all the node values
1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for(;;)
1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeNode* parent;
1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for(;;)
1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->value *= scale;
1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !node->left )
1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                break;
1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = node->left;
1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( parent = node->parent; parent && parent->right == node;
1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = parent, parent = parent->parent )
1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ;
1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !parent )
1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node = parent->right;
1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid
1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::try_split_node( CvDTreeNode* node )
1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTree::try_split_node( node );
1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !node->left )
1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // if the node has not been split,
1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // store the responses for the corresponding training samples
1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double* weak_eval = ensemble->get_weak_response()->data.db;
1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* labels = data->get_labels( node );
1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i, count = node->sample_count;
1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double value = node->value;
1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; i++ )
1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            weak_eval[labels[i]] = value;
1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renndouble
1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::calc_node_dir( CvDTreeNode* node )
1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    char* dir = (char*)data->direction->data.ptr;
1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* weights = ensemble->get_subtree_weights()->data.db;
1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, n = node->sample_count, vi = node->split->var_idx;
1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double L, R;
1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert( !node->split->inversed );
1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->get_var_type(vi) >= 0 ) // split on categorical var
1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* cat_labels = data->get_cat_var_data( node, vi );
1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* subset = node->split->subset;
1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum = 0, sum_abs = 0;
1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = cat_labels[i];
2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = weights[i];
2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum += d*w; sum_abs += (d & 1)*w;
2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dir[i] = (char)d;
2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        R = (sum_abs + sum) * 0.5;
2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        L = (sum_abs - sum) * 0.5;
2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else // split on ordered var
2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int split_point = node->split->ord.split_point;
2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n1 = node->get_num_valid(vi);
2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        assert( 0 <= split_point && split_point < n1-1 );
2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        L = R = 0;
2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i <= split_point; i++ )
2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = sorted[i].i;
2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = weights[idx];
2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dir[idx] = (char)-1;
2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L += w;
2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ; i < n1; i++ )
2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = sorted[i].i;
2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = weights[idx];
2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dir[idx] = (char)1;
2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            R += w;
2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ; i < n; i++ )
2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dir[sorted[i].i] = (char)0;
2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->maxlr = MAX( L, R );
2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return node->split->quality/(L + R);
2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit*
2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi )
2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float epsilon = FLT_EPSILON*2;
2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* responses = data->get_class_labels(node);
2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* weights = ensemble->get_subtree_weights()->data.db;
2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n1 = node->get_num_valid(vi);
2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* rcw0 = weights + n;
2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double lcw[2] = {0,0}, rcw[2];
2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, best_i = -1;
2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val = 0;
2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int boost_type = ensemble->get_params().boost_type;
2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int split_criteria = ensemble->get_params().split_criteria;
2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rcw[0] = rcw0[0]; rcw[1] = rcw0[1];
2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = n1; i < n; i++ )
2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = sorted[i].i;
2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weights[idx];
2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rcw[responses[idx]] -= w;
2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( split_criteria == CvBoost::GINI )
2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double L = 0, R = rcw[0] + rcw[1];
2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n1 - 1; i++ )
2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = sorted[i].i;
2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = weights[idx], w2 = w*w;
2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double lv, rv;
2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            idx = responses[idx];
2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L += w; R -= w;
2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lv = lcw[idx]; rv = rcw[idx];
2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lsum2 += 2*lv*w + w2;
2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rsum2 -= 2*rv*w - w2;
2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lcw[idx] = lv + w; rcw[idx] = rv - w;
2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( sorted[i].val + epsilon < sorted[i+1].val )
2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = (lsum2*R + rsum2*L)/(L*R);
2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( best_val < val )
2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = val;
2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_i = i;
2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n1 - 1; i++ )
3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = sorted[i].i;
3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = weights[idx];
3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            idx = responses[idx];
3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lcw[idx] += w;
3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rcw[idx] -= w;
3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( sorted[i].val + epsilon < sorted[i+1].val )
3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];
3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                val = MAX(val, val2);
3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( best_val < val )
3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = val;
3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_i = i;
3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return best_i >= 0 ? data->new_split_ord( vi,
3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        0, (float)best_val ) : 0;
3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit*
3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi )
3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split;
3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* cat_labels = data->get_cat_var_data(node, vi);
3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* responses = data->get_class_labels(node);
3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int ci = data->get_var_type(vi);
3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int mi = data->cat_count->data.i[ci];
3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double lcw[2]={0,0}, rcw[2]={0,0};
3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* cjk = (double*)cvStackAlloc(2*(mi+1)*sizeof(cjk[0]))+2;
3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* weights = ensemble->get_subtree_weights()->data.db;
3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double** dbl_ptr = (double**)cvStackAlloc( mi*sizeof(dbl_ptr[0]) );
3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, k, idx;
3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double L = 0, R;
3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val = 0;
3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int best_subset = -1, subset_i;
3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int boost_type = ensemble->get_params().boost_type;
3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int split_criteria = ensemble->get_params().split_criteria;
3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // init array of counters:
3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( j = -1; j < mi; j++ )
3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cjk[j*2] = cjk[j*2+1] = 0;
3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n; i++ )
3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weights[i];
3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        j = cat_labels[i];
3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        k = responses[i];
3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cjk[j*2 + k] += w;
3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( j = 0; j < mi; j++ )
3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rcw[0] += cjk[j*2];
3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rcw[1] += cjk[j*2+1];
3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dbl_ptr[j] = cjk + j*2 + 1;
3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    R = rcw[0] + rcw[1];
3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // sort rows of c_jk by increasing c_j,1
3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // (i.e. by the weight of samples in j-th category that belong to class 1)
3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    icvSortDblPtr( dbl_ptr, mi, 0 );
3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( subset_i = 0; subset_i < mi-1; subset_i++ )
3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        idx = (int)(dbl_ptr[subset_i] - cjk)/2;
3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const double* crow = cjk + idx*2;
3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w0 = crow[0], w1 = crow[1];
3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double weight = w0 + w1;
3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( weight < FLT_EPSILON )
3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            continue;
3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lcw[0] += w0; rcw[0] -= w0;
3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lcw[1] += w1; rcw[1] -= w1;
3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( split_criteria == CvBoost::GINI )
3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];
3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L += weight;
3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            R -= weight;
3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( L > FLT_EPSILON && R > FLT_EPSILON )
4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = (lsum2*R + rsum2*L)/(L*R);
4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( best_val < val )
4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = val;
4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_subset = subset_i;
4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double val = lcw[0] + rcw[1];
4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double val2 = lcw[1] + rcw[0];
4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            val = MAX(val, val2);
4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( best_val < val )
4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_val = val;
4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_subset = subset_i;
4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( best_subset < 0 )
4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return 0;
4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split = data->new_split_cat( vi, (float)best_val );
4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i <= best_subset; i++ )
4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        idx = (int)(dbl_ptr[i] - cjk) >> 1;
4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split->subset[idx >> 5] |= 1 << (idx & 31);
4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return split;
4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit*
4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi )
4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float epsilon = FLT_EPSILON*2;
4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float* responses = data->get_ord_responses(node);
4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* weights = ensemble->get_subtree_weights()->data.db;
4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n1 = node->get_num_valid(vi);
4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, best_i = -1;
4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val = 0, lsum = 0, rsum = node->value*n;
4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double L = 0, R = weights[n];
4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compensate for missing values
4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = n1; i < n; i++ )
4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = sorted[i].i;
4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weights[idx];
4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rsum -= responses[idx]*w;
4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        R -= w;
4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // find the optimal split
4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n1 - 1; i++ )
4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = sorted[i].i;
4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weights[idx];
4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double t = responses[idx]*w;
4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        L += w; R -= w;
4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lsum += t; rsum -= t;
4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( sorted[i].val + epsilon < sorted[i+1].val )
4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( best_val < val )
4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_val = val;
4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_i = i;
4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return best_i >= 0 ? data->new_split_ord( vi,
4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        0, (float)best_val ) : 0;
4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit*
4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi )
4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split;
4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* cat_labels = data->get_cat_var_data(node, vi);
4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float* responses = data->get_ord_responses(node);
4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* weights = ensemble->get_subtree_weights()->data.db;
4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int ci = data->get_var_type(vi);
4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int mi = data->cat_count->data.i[ci];
4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* counts = (double*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) );
5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double L = 0, R = 0, best_val = 0, lsum = 0, rsum = 0;
5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, best_subset = -1, subset_i;
5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = -1; i < mi; i++ )
5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum[i] = counts[i] = 0;
5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // calculate sum response and weight of each category of the input var
5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n; i++ )
5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = cat_labels[i];
5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weights[i];
5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double s = sum[idx] + responses[i]*w;
5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double nc = counts[idx] + w;
5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum[idx] = s;
5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        counts[idx] = nc;
5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // calculate average response in each category
5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < mi; i++ )
5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        R += counts[i];
5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rsum += sum[i];
5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum[i] /= counts[i];
5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum_ptr[i] = sum + i;
5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    icvSortDblPtr( sum_ptr, mi, 0 );
5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // revert back to unnormalized sums
5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // (there should be a very little loss in accuracy)
5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < mi; i++ )
5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum[i] *= counts[i];
5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( subset_i = 0; subset_i < mi-1; subset_i++ )
5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = (int)(sum_ptr[subset_i] - sum);
5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double ni = counts[idx];
5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( ni > FLT_EPSILON )
5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double s = sum[idx];
5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lsum += s; L += ni;
5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rsum -= s; R -= ni;
5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( L > FLT_EPSILON && R > FLT_EPSILON )
5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( best_val < val )
5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = val;
5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_subset = subset_i;
5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( best_subset < 0 )
5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return 0;
5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split = data->new_split_cat( vi, (float)best_val );
5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i <= best_subset; i++ )
5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = (int)(sum_ptr[i] - sum);
5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split->subset[idx >> 5] |= 1 << (idx & 31);
5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return split;
5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit*
5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float epsilon = FLT_EPSILON*2;
5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* weights = ensemble->get_subtree_weights()->data.db;
5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const char* dir = (char*)data->direction->data.ptr;
5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n1 = node->get_num_valid(vi);
5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // LL - number of samples that both the primary and the surrogate splits send to the left
5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // LR - ... primary split sends to the left and the surrogate split sends to the right
5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // RL - ... primary split sends to the right and the surrogate split sends to the left
5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // RR - ... both send to the right
5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, best_i = -1, best_inversed = 0;
5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val;
5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double LL = 0, RL = 0, LR, RR;
5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double worst_val = node->maxlr;
5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double sum = 0, sum_abs = 0;
5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    best_val = worst_val;
5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n1; i++ )
5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = sorted[i].i;
5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weights[idx];
5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int d = dir[idx];
5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum += d*w; sum_abs += (d & 1)*w;
5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // sum_abs = R + L; sum = R - L
5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    RR = (sum_abs + sum)*0.5;
5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    LR = (sum_abs - sum)*0.5;
6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // initially all the samples are sent to the right by the surrogate split,
6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // LR of them are sent to the left by primary split, and RR - to the right.
6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n1 - 1; i++ )
6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = sorted[i].i;
6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weights[idx];
6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int d = dir[idx];
6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( d < 0 )
6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            LL += w; LR -= w;
6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_val = LL + RR;
6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_i = i; best_inversed = 0;
6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( d > 0 )
6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            RL += w; RR -= w;
6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_val = RL + LR;
6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_i = i; best_inversed = 1;
6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        best_inversed, (float)best_val ) : 0;
6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit*
6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* cat_labels = data->get_cat_var_data(node, vi);
6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const char* dir = (char*)data->direction->data.ptr;
6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* weights = ensemble->get_subtree_weights()->data.db;
6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // LL - number of samples that both the primary and the surrogate splits send to the left
6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // LR - ... primary split sends to the left and the surrogate split sends to the right
6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // RL - ... primary split sends to the right and the surrogate split sends to the left
6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // RR - ... both send to the right
6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val = 0;
6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* rc = lc + mi + 1;
6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = -1; i < mi; i++ )
6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lc[i] = rc[i] = 0;
6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // 1. for each category calculate the weight of samples
6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // sent to the left (lc) and to the right (rc) by the primary split
6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n; i++ )
6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = cat_labels[i];
6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weights[i];
6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int d = dir[i];
6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum = lc[idx] + d*w;
6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum_abs = rc[idx] + (d & 1)*w;
6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lc[idx] = sum; rc[idx] = sum_abs;
6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < mi; i++ )
6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum = lc[i];
6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum_abs = rc[i];
6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lc[i] = (sum_abs - sum) * 0.5;
6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rc[i] = (sum_abs + sum) * 0.5;
6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // 2. now form the split.
6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // in each category send all the samples to the same direction as majority
6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < mi; i++ )
6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double lval = lc[i], rval = rc[i];
6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( lval > rval )
6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            split->subset[i >> 5] |= 1 << (i & 31);
6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            best_val += lval;
6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            best_val += rval;
6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->quality = (float)best_val;
6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( split->quality <= node->maxlr )
6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSetRemoveByPtr( data->split_heap, split ), split = 0;
6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return split;
6956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid
6996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::calc_node_value( CvDTreeNode* node )
7006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
7016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, count = node->sample_count;
7026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* weights = ensemble->get_weights()->data.db;
7036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* labels = data->get_labels(node);
7046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* subtree_weights = ensemble->get_subtree_weights()->data.db;
7056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double rcw[2] = {0,0};
7066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int boost_type = ensemble->get_params().boost_type;
7076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //const double* priors = data->priors->data.db;
7086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->is_classifier )
7106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* responses = data->get_class_labels(node);
7126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; i++ )
7146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = labels[i];
7166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = weights[idx]/*priors[responses[i]]*/;
7176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rcw[responses[i]] += w;
7186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            subtree_weights[i] = w;
7196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->class_idx = rcw[1] > rcw[0];
7226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( boost_type == CvBoost::DISCRETE )
7246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // ignore cat_map for responses, and use {-1,1},
7266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // as the whole ensemble response is computes as sign(sum_i(weak_response_i)
7276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->value = node->class_idx*2 - 1;
7286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
7306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double p = rcw[1]/(rcw[0] + rcw[1]);
7326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            assert( boost_type == CvBoost::REAL );
7336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // store log-ratio of the probability
7356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->value = 0.5*log_ratio(p);
7366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
7396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // in case of regression tree:
7416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
7426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    n is the number of samples in the node.
7436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
7446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum = 0, sum2 = 0, iw;
7456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const float* values = data->get_ord_responses(node);
7466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; i++ )
7486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = labels[i];
7506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = weights[idx]/*priors[values[i] > 0]*/;
7516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double t = values[i];
7526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rcw[0] += w;
7536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            subtree_weights[i] = w;
7546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum += t*w;
7556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum2 += t*t*w;
7566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        iw = 1./rcw[0];
7596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->value = sum*iw;
7606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->node_risk = sum2 - (sum*iw)*sum;
7616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // renormalize the risk, as in try_split_node the unweighted formula
7636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)
7646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->node_risk *= count*iw*count*iw;
7656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // store summary weights
7686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    subtree_weights[count] = rcw[0];
7696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    subtree_weights[count+1] = rcw[1];
7706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
7716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )
7746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
7756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTree::read( fs, fnode, _data );
7766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ensemble = _ensemble;
7776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
7786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoostTree::read( CvFileStorage*, CvFileNode* )
7816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
7826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert(0);
7836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
7846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node,
7866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        CvDTreeTrainData* _data )
7876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
7886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTree::read( _fs, _node, _data );
7896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
7906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/////////////////////////////////// CvBoost /////////////////////////////////////
7936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::CvBoost()
7956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
7966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = 0;
7976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weak = 0;
7986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default_model_name = "my_boost_tree";
7996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    orig_response = sum_response = weak_eval = subsample_mask =
8006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        weights = subtree_weights = 0;
8016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
8036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoost::prune( CvSlice slice )
8076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( weak )
8096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvSeqReader reader;
8116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i, count = cvSliceLength( slice, weak );
8126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvStartReadSeq( weak, &reader );
8146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSetSeqReaderPos( &reader, slice.start_index );
8156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; i++ )
8176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
8186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvBoostTree* w;
8196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_READ_SEQ_ELEM( w, reader );
8206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            delete w;
8216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
8226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSeqRemoveSlice( weak, slice );
8246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoost::clear()
8296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( weak )
8316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        prune( CV_WHOLE_SEQ );
8336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMemStorage( &weak->storage );
8346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data )
8366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        delete data;
8376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weak = 0;
8386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = 0;
8396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &orig_response );
8406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &sum_response );
8416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &weak_eval );
8426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &subsample_mask );
8436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &weights );
8446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    have_subsample = false;
8456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::~CvBoost()
8496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
8516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::CvBoost( const CvMat* _train_data, int _tflag,
8556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                  const CvMat* _responses, const CvMat* _var_idx,
8566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                  const CvMat* _sample_idx, const CvMat* _var_type,
8576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                  const CvMat* _missing_mask, CvBoostParams _params )
8586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weak = 0;
8606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = 0;
8616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default_model_name = "my_boost_tree";
8626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    orig_response = sum_response = weak_eval = subsample_mask = weights = 0;
8636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
8656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn           _var_type, _missing_mask, _params );
8666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool
8706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::set_params( const CvBoostParams& _params )
8716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool ok = false;
8736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvBoost::set_params" );
8756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
8776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params = _params;
8796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.boost_type != DISCRETE && params.boost_type != REAL &&
8806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.boost_type != LOGIT && params.boost_type != GENTLE )
8816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );
8826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.weak_count = MAX( params.weak_count, 1 );
8846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );
8856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.weight_trim_rate = MIN( params.weight_trim_rate, 1. );
8866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.weight_trim_rate < FLT_EPSILON )
8876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.weight_trim_rate = 1.f;
8886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.boost_type == DISCRETE &&
8906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria != GINI && params.split_criteria != MISCLASS )
8916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria = MISCLASS;
8926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.boost_type == REAL &&
8936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria != GINI && params.split_criteria != MISCLASS )
8946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria = GINI;
8956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&
8966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria != SQERR )
8976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria = SQERR;
8986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ok = true;
9006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
9026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return ok;
9046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool
9086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::train( const CvMat* _train_data, int _tflag,
9096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn              const CvMat* _responses, const CvMat* _var_idx,
9106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn              const CvMat* _sample_idx, const CvMat* _var_type,
9116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn              const CvMat* _missing_mask,
9126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn              CvBoostParams _params, bool _update )
9136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool ok = false;
9156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMemStorage* storage = 0;
9166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvBoost::train" );
9186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
9206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i;
9226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    set_params( _params );
9246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !_update || !data )
9266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        clear();
9286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
9296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _sample_idx, _var_type, _missing_mask, _params, true, true );
9306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( data->get_num_classes() != 2 )
9326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsNotImplemented,
9336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "Boosted trees can only be used for 2-class classification." );
9346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( storage = cvCreateMemStorage() );
9356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
9366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        storage = 0;
9376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
9386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
9396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        data->set_data( _train_data, _tflag, _responses, _var_idx,
9416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _sample_idx, _var_type, _missing_mask, _params, true, true, true );
9426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
9436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    update_weights( 0 );
9456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < params.weak_count; i++ )
9476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvBoostTree* tree = new CvBoostTree;
9496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !tree->train( data, subsample_mask, this ) )
9506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            delete tree;
9526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            continue;
9536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //cvCheckArr( get_weak_response());
9556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSeqPush( weak, &tree );
9566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        update_weights( tree );
9576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        trim_weights();
9586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
9596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->is_classifier = true;
9616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ok = true;
9626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
9646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return ok;
9666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid
9706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::update_weights( CvBoostTree* tree )
9716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvBoost::update_weights" );
9736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
9756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, count = data->sample_count;
9776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double sumw = 0.;
9786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !tree ) // before training the first tree, initialize weights and other parameters
9806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* class_labels = data->get_class_labels(data->data_root);
9826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // in case of logitboost and gentle adaboost each weak tree is a regression tree,
9836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // so we need to convert class labels to floating-point values
9846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float* responses = data->get_ord_responses(data->data_root);
9856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* labels = data->get_labels(data->data_root);
9866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w0 = 1./count;
9876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double p[2] = { 1, 1 };
9886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &orig_response );
9906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &sum_response );
9916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &weak_eval );
9926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &subsample_mask );
9936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &weights );
9946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( orig_response = cvCreateMat( 1, count, CV_32S ));
9966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( weak_eval = cvCreateMat( 1, count, CV_64F ));
9976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( subsample_mask = cvCreateMat( 1, count, CV_8U ));
9986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( weights = cvCreateMat( 1, count, CV_64F ));
9996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( subtree_weights = cvCreateMat( 1, count + 2, CV_64F ));
10006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( data->have_priors )
10026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // compute weight scale for each class from their prior probabilities
10046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int c1 = 0;
10056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
10066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                c1 += class_labels[i];
10076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            p[0] = data->priors->data.db[0]*(c1 < count ? 1./(count - c1) : 0.);
10086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.);
10096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            p[0] /= p[0] + p[1];
10106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            p[1] = 1. - p[0];
10116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; i++ )
10146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // save original categorical responses {0,1}, convert them to {-1,1}
10166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            orig_response->data.i[i] = class_labels[i]*2 - 1;
10176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // make all the samples active at start.
10186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // later, in trim_weights() deactivate/reactive again some, if need
10196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            subsample_mask->data.ptr[i] = (uchar)1;
10206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // make all the initial weights the same.
10216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            weights->data.db[i] = w0*p[class_labels[i]];
10226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // set the labels to find (from within weak tree learning proc)
10236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // the particular sample weight, and where to store the response.
10246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            labels[i] = i;
10256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( params.boost_type == LOGIT )
10286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL( sum_response = cvCreateMat( 1, count, CV_64F ));
10306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
10326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
10336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum_response->data.db[i] = 0;
10346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                responses[i] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
10356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
10366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // in case of logitboost each weak tree is a regression tree.
10386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // the target function values are recalculated for each of the trees
10396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            data->is_classifier = false;
10406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( params.boost_type == GENTLE )
10426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
10446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                responses[i] = (float)orig_response->data.i[i];
10456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            data->is_classifier = false;
10476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
10496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
10506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // at this moment, for all the samples that participated in the training of the most
10526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // recent weak classifier we know the responses. For other samples we need to compute them
10536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( have_subsample )
10546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float* values = (float*)(data->buf->data.ptr + data->buf->step);
10566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            uchar* missing = data->buf->data.ptr + data->buf->step*2;
10576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvMat _sample, _mask;
10586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // invert the subsample mask
10606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvXorS( subsample_mask, cvScalar(1.), subsample_mask );
10616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            data->get_vectors( subsample_mask, values, missing, 0 );
10626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //data->get_vectors( 0, values, missing, 0 );
10636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _sample = cvMat( 1, data->var_count, CV_32F );
10656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _mask = cvMat( 1, data->var_count, CV_8U );
10666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // run tree through all the non-processed samples
10686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
10696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( subsample_mask->data.ptr[i] )
10706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
10716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    _sample.data.fl = values;
10726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    _mask.data.ptr = missing;
10736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    values += _sample.cols;
10746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    missing += _mask.cols;
10756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;
10766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
10776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // now update weights and other parameters for each type of boosting
10806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( params.boost_type == DISCRETE )
10816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // Discrete AdaBoost:
10836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   weak_eval[i] (=f(x_i)) is in {-1,1}
10846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
10856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   C = log((1-err)/err)
10866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   w_i *= exp(C*(f(x_i) != y_i))
10876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double C, err = 0.;
10896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double scale[] = { 1., 0. };
10906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
10926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
10936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double w = weights->data.db[i];
10946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sumw += w;
10956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
10966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
10976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( sumw != 0 )
10996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                err /= sumw;
11006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            C = err = -log_ratio( err );
11016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            scale[1] = exp(err);
11026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sumw = 0;
11046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
11056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
11066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double w = weights->data.db[i]*
11076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    scale[weak_eval->data.db[i] != orig_response->data.i[i]];
11086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sumw += w;
11096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                weights->data.db[i] = w;
11106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
11116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            tree->scale( C );
11136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
11146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( params.boost_type == REAL )
11156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
11166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // Real AdaBoost:
11176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
11186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   w_i *= exp(-y_i*f(x_i))
11196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
11216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                weak_eval->data.db[i] *= -orig_response->data.i[i];
11226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvExp( weak_eval, weak_eval );
11246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
11266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
11276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double w = weights->data.db[i]*weak_eval->data.db[i];
11286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sumw += w;
11296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                weights->data.db[i] = w;
11306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
11316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
11326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( params.boost_type == LOGIT )
11336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
11346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // LogitBoost:
11356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   weak_eval[i] = f(x_i) in [-z_max,z_max]
11366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   sum_response = F(x_i).
11376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   F(x_i) += 0.5*f(x_i)
11386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))
11396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   reuse weak_eval: weak_eval[i] <- p(x_i)
11406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   w_i = p(x_i)*1(1 - p(x_i))
11416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
11426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   store z_i to the data->data_root as the new target responses
11436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const double lb_weight_thresh = FLT_EPSILON;
11456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const double lb_z_max = 10.;
11466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float* responses = data->get_ord_responses(data->data_root);
11476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            /*if( weak->total == 7 )
11496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                putchar('*');*/
11506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
11526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
11536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
11546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum_response->data.db[i] = s;
11556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                weak_eval->data.db[i] = -2*s;
11566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
11576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvExp( weak_eval, weak_eval );
11596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
11616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
11626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double p = 1./(1. + weak_eval->data.db[i]);
11636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double w = p*(1 - p), z;
11646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                w = MAX( w, lb_weight_thresh );
11656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                weights->data.db[i] = w;
11666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sumw += w;
11676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( orig_response->data.i[i] > 0 )
11686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
11696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    z = 1./p;
11706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    responses[i] = (float)MIN(z, lb_z_max);
11716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
11726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                else
11736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
11746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    z = 1./(1-p);
11756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    responses[i] = (float)-MIN(z, lb_z_max);
11766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
11776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
11786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
11796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
11806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
11816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // Gentle AdaBoost:
11826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   weak_eval[i] = f(x_i) in [-1,1]
11836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //   w_i *= exp(-y_i*f(x_i))
11846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            assert( params.boost_type == GENTLE );
11856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
11876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                weak_eval->data.db[i] *= -orig_response->data.i[i];
11886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvExp( weak_eval, weak_eval );
11906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
11926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
11936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double w = weights->data.db[i] * weak_eval->data.db[i];
11946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                weights->data.db[i] = w;
11956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sumw += w;
11966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
11976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
11986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
11996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // renormalize weights
12016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( sumw > FLT_EPSILON )
12026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
12036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sumw = 1./sumw;
12046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; ++i )
12056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            weights->data.db[i] *= sumw;
12066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
12076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
12096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int )
12136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid
12166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::trim_weights()
12176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvBoost::trim_weights" );
12196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
12216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, count = data->sample_count, nz_count = 0;
12236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double sum, threshold;
12246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )
12266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        EXIT;
12276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // use weak_eval as temporary buffer for sorted weights
12296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvCopy( weights, weak_eval );
12306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    icvSort_64f( weak_eval->data.db, count, 0 );
12326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // as weight trimming occurs immediately after updating the weights,
12346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // where they are renormalized, we assume that the weight sum = 1.
12356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    sum = 1. - params.weight_trim_rate;
12366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < count; i++ )
12386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
12396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weak_eval->data.db[i];
12406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( sum > w )
12416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
12426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum -= w;
12436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
12446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    threshold = i < count ? weak_eval->data.db[i] : DBL_MAX;
12466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < count; i++ )
12486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
12496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double w = weights->data.db[i];
12506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int f = w > threshold;
12516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        subsample_mask->data.ptr[i] = (uchar)f;
12526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        nz_count += f;
12536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
12546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    have_subsample = nz_count < count;
12566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
12586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat
12626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::predict( const CvMat* _sample, const CvMat* _missing,
12636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                  CvMat* weak_responses, CvSlice slice,
12646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                  bool raw_mode ) const
12656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* buf = 0;
12676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool allocated = false;
12686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float value = -FLT_MAX;
12696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvBoost::predict" );
12716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
12736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, weak_count, var_count;
12756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat sample, missing;
12766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvSeqReader reader;
12776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double sum = 0;
12786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int cls_idx;
12796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int wstep = 0;
12806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* vtype;
12816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* cmap;
12826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* cofs;
12836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !weak )
12856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );
12866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
12886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _sample->cols != 1 && _sample->rows != 1 ||
12896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _sample->cols + _sample->rows - 1 != data->var_all && !raw_mode ||
12906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _sample->cols + _sample->rows - 1 != data->var_count && raw_mode )
12916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg,
12926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "the input sample must be 1d floating-point vector with the same "
12936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "number of elements as the total number of variables used for training" );
12946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _missing )
12966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
12976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
12986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            !CV_ARE_SIZES_EQ(_missing, _sample) )
12996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg,
13006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "the missing data mask must be 8-bit vector of the same size as input sample" );
13016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weak_count = cvSliceLength( slice, weak );
13046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( weak_count >= weak->total )
13056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        weak_count = weak->total;
13076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        slice.start_index = 0;
13086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( weak_responses )
13116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(weak_responses) ||
13136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||
13146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            weak_responses->cols != 1 && weak_responses->rows != 1 ||
13156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            weak_responses->cols + weak_responses->rows - 1 != weak_count )
13166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg,
13176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "The output matrix of weak classifier responses must be valid "
13186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "floating-point vector of the same number of components as the length of input slice" );
13196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);
13206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_count = data->var_count;
13236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    vtype = data->var_type->data.i;
13246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cmap = data->cat_map->data.i;
13256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cofs = data->cat_ofs->data.i;
13266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // if need, preprocess the input vector
13286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !raw_mode && (data->cat_var_count > 0 || data->var_idx) )
13296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int bufsize;
13316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int step, mstep = 0;
13326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const float* src_sample;
13336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const uchar* src_mask = 0;
13346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float* dst_sample;
13356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        uchar* dst_mask;
13366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* vidx = data->var_idx && !raw_mode ? data->var_idx->data.i : 0;
13376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        bool have_mask = _missing != 0;
13386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        bufsize = var_count*(sizeof(float) + sizeof(uchar));
13406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( bufsize <= CV_MAX_LOCAL_SIZE )
13416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            buf = (float*)cvStackAlloc( bufsize );
13426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
13436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
13446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL( buf = (float*)cvAlloc( bufsize ));
13456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            allocated = true;
13466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
13476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dst_sample = buf;
13486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dst_mask = (uchar*)(buf + var_count);
13496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        src_sample = _sample->data.fl;
13516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);
13526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( _missing )
13546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
13556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            src_mask = _missing->data.ptr;
13566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step;
13576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
13586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < var_count; i++ )
13606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
13616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = vidx ? vidx[i] : i;
13626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float val = src_sample[idx*step];
13636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int ci = vtype[i];
13646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            uchar m = src_mask ? src_mask[i] : (uchar)0;
13656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( ci >= 0 )
13676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
13686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int a = cofs[ci], b = cofs[ci+1], c = a;
13696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int ival = cvRound(val);
13706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( ival != val )
13716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    CV_ERROR( CV_StsBadArg,
13726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    "one of input categorical variable is not an integer" );
13736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                while( a < b )
13756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
13766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    c = (a + b) >> 1;
13776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( ival < cmap[c] )
13786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        b = c;
13796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    else if( ival > cmap[c] )
13806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        a = c+1;
13816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    else
13826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        break;
13836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
13846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( c < 0 || ival != cmap[c] )
13866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
13876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    m = 1;
13886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    have_mask = true;
13896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
13906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                else
13916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
13926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    val = (float)(c - cofs[ci]);
13936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
13946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
13956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dst_sample[i] = val;
13976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dst_mask[i] = m;
13986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
13996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sample = cvMat( 1, var_count, CV_32F, dst_sample );
14016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _sample = &sample;
14026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( have_mask )
14046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
14056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            missing = cvMat( 1, var_count, CV_8UC1, dst_mask );
14066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _missing = &missing;
14076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
14086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
14096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartReadSeq( weak, &reader );
14116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvSetSeqReaderPos( &reader, slice.start_index );
14126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < weak_count; i++ )
14146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
14156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvBoostTree* wtree;
14166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double val;
14176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_READ_SEQ_ELEM( wtree, reader );
14196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        val = wtree->predict( _sample, _missing, true )->value;
14216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( weak_responses )
14226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            weak_responses->data.fl[i*wstep] = (float)val;
14236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum += val;
14256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
14266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cls_idx = sum >= 0;
14286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( raw_mode )
14296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        value = (float)cls_idx;
14306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
14316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        value = (float)cmap[cofs[vtype[var_count]] + cls_idx];
14326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
14346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( allocated )
14366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &buf );
14376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return value;
14396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
14406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoost::write_params( CvFileStorage* fs )
14446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
14456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvBoost::write_params" );
14466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
14486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const char* boost_type_str =
14506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.boost_type == DISCRETE ? "DiscreteAdaboost" :
14516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.boost_type == REAL ? "RealAdaboost" :
14526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.boost_type == LOGIT ? "LogitBoost" :
14536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.boost_type == GENTLE ? "GentleAdaboost" : 0;
14546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const char* split_crit_str =
14566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria == DEFAULT ? "Default" :
14576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria == GINI ? "Gini" :
14586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.boost_type == MISCLASS ? "Misclassification" :
14596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.boost_type == SQERR ? "SquaredErr" : 0;
14606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( boost_type_str )
14626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteString( fs, "boosting_type", boost_type_str );
14636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
14646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteInt( fs, "boosting_type", params.boost_type );
14656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( split_crit_str )
14676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteString( fs, "splitting_criteria", split_crit_str );
14686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
14696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteInt( fs, "splitting_criteria", params.split_criteria );
14706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "ntrees", params.weak_count );
14726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );
14736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->write_params( fs );
14756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
14776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
14786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )
14816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
14826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvBoost::read_params" );
14836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
14856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvFileNode* temp;
14876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
14896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return;
14906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = new CvDTreeTrainData();
14926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( data->read_params(fs, fnode));
14936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->shared = true;
14946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.max_depth = data->params.max_depth;
14966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.min_sample_count = data->params.min_sample_count;
14976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.max_categories = data->params.max_categories;
14986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.priors = data->params.priors;
14996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.regression_accuracy = data->params.regression_accuracy;
15006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.use_surrogates = data->params.use_surrogates;
15016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );
15036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !temp )
15046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return;
15056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( temp && CV_NODE_IS_STRING(temp->tag) )
15076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const char* boost_type_str = cvReadString( temp, "" );
15096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :
15106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :
15116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :
15126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1;
15136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
15156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.boost_type = cvReadInt( temp, -1 );
15166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.boost_type < DISCRETE || params.boost_type > GENTLE )
15186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
15196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );
15216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( temp && CV_NODE_IS_STRING(temp->tag) )
15226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const char* split_crit_str = cvReadString( temp, "" );
15246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :
15256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                strcmp( split_crit_str, "Gini" ) == 0 ? GINI :
15266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :
15276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;
15286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
15306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.split_criteria = cvReadInt( temp, -1 );
15316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.split_criteria < DEFAULT || params.boost_type > SQERR )
15336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
15346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );
15366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );
15376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
15396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
15406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid
15446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::read( CvFileStorage* fs, CvFileNode* node )
15456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
15466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvRTrees::read" );
15476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
15496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvSeqReader reader;
15516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvFileNode* trees_fnode;
15526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMemStorage* storage;
15536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, ntrees;
15546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
15566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    read_params( fs, node );
15576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !data )
15596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        EXIT;
15606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    trees_fnode = cvGetFileNodeByName( fs, node, "trees" );
15626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
15636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
15646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartReadSeq( trees_fnode->data.seq, &reader );
15666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ntrees = trees_fnode->data.seq->total;
15676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( ntrees != params.weak_count )
15696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsUnmatchedSizes,
15706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "The number of trees stored does not match <ntrees> tag value" );
15716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( storage = cvCreateMemStorage() );
15736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
15746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < ntrees; i++ )
15766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvBoostTree* tree = new CvBoostTree();
15786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data ));
15796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
15806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSeqPush( weak, &tree );
15816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
15846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
15856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid
15886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::write( CvFileStorage* fs, const char* name )
15896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
15906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvBoost::write" );
15916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
15936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvSeqReader reader;
15956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i;
15966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );
15986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !weak )
16006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );
16016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    write_params( fs );
16036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
16046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartReadSeq( weak, &reader );
16066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < weak->total; i++ )
16086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
16096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvBoostTree* tree;
16106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_READ_SEQ_ELEM( tree, reader );
16116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
16126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        tree->write( fs );
16136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvEndWriteStruct( fs );
16146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
16156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
16176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
16186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
16206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
16216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvMat*
16246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::get_weights()
16256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
16266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return weights;
16276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
16286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvMat*
16316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::get_subtree_weights()
16326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
16336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return subtree_weights;
16346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
16356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvMat*
16386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::get_weak_response()
16396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
16406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return weak_eval;
16416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
16426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvBoostParams&
16456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::get_params() const
16466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
16476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return params;
16486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
16496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvSeq* CvBoost::get_weak_predictors()
16516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
16526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return weak;
16536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
16546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* End of file. */
1656