/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct, 326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages 336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// (including, but not limited to, procurement of substitute goods or services; 346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused 356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability, 366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort (including negligence or otherwise) arising in any way out of 376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even if advised of the possibility of such damage. 386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// 396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/ 406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h" 426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic inline double 446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennlog_ratio( double val ) 456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double eps = 1e-5; 476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = MAX( val, eps ); 496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = MIN( val, 1. - eps ); 506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return log( val/(1. 
- val) ); 516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostParams::CvBoostParams() 556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn boost_type = CvBoost::REAL; 576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_count = 100; 586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weight_trim_rate = 0.95; 596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_folds = 0; 606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_depth = 1; 616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostParams::CvBoostParams( int _boost_type, int _weak_count, 656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double _weight_trim_rate, int _max_depth, 666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool _use_surrogates, const float* _priors ) 676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn boost_type = _boost_type; 696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_count = _weak_count; 706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weight_trim_rate = _weight_trim_rate; 716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split_criteria = CvBoost::DEFAULT; 726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cv_folds = 0; 736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn max_depth = _max_depth; 746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn use_surrogates = _use_surrogates; 756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn priors = _priors; 766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn///////////////////////////////// CvBoostTree /////////////////////////////////// 816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::CvBoostTree() 836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ensemble = 0; 856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::~CvBoostTree() 896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid 956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::clear() 966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTree::clear(); 986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ensemble = 0; 996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool 1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::train( CvDTreeTrainData* _train_data, 1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _subsample_idx, CvBoost* _ensemble ) 1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ensemble = _ensemble; 
1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = _train_data; 1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->shared = true; 1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return do_train( _subsample_idx ); 1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool 1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*, 1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat*, const CvMat*, const CvMat*, CvDTreeParams ) 1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert(0); 1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return false; 1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool 1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::train( CvDTreeTrainData*, const CvMat* ) 1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert(0); 1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return false; 1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid 1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::scale( double scale ) 1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* node = root; 1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // traverse the tree and scale all the node values 1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeNode* parent; 1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for(;;) 1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->value *= scale; 1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !node->left ) 1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = node->left; 1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( parent = node->parent; parent && parent->right == node; 1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent, parent = parent->parent ) 1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ; 1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !parent ) 1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node = parent->right; 1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid 1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::try_split_node( CvDTreeNode* node ) 1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTree::try_split_node( node ); 1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !node->left ) 1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // if the node has not been split, 1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // store the responses for the corresponding training samples 1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* weak_eval = ensemble->get_weak_response()->data.db; 1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* labels = data->get_labels( node ); 1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, count = node->sample_count; 1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double value = node->value; 1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_eval[labels[i]] = value; 1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renndouble 1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::calc_node_dir( CvDTreeNode* node ) 1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn char* dir = (char*)data->direction->data.ptr; 1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* weights = ensemble->get_subtree_weights()->data.db; 1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, n = node->sample_count, vi = node->split->var_idx; 1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double L, R; 1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( !node->split->inversed ); 1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->get_var_type(vi) >= 0 ) // split on categorical var 1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* cat_labels = data->get_cat_var_data( node, vi ); 1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* subset = node->split->subset; 1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0, sum_abs = 0; 1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = cat_labels[i]; 2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[i]; 2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0; 2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += d*w; sum_abs += (d & 1)*w; 2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[i] = (char)d; 2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R = (sum_abs + sum) * 0.5; 2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L = (sum_abs - sum) * 0.5; 2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else // split on ordered var 2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* sorted = data->get_ord_var_data(node,vi); 2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int split_point = node->split->ord.split_point; 2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( 0 <= 
split_point && split_point < n1-1 ); 2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L = R = 0; 2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i <= split_point; i++ ) 2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]; 2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[idx] = (char)-1; 2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L += w; 2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < n1; i++ ) 2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]; 2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[idx] = (char)1; 2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R += w; 2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( ; i < n; i++ ) 2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dir[sorted[i].i] = (char)0; 2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->maxlr = MAX( L, R ); 2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return node->split->quality/(L + R); 2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* 2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
RennCvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi ) 2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float epsilon = FLT_EPSILON*2; 2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* sorted = data->get_ord_var_data(node, vi); 2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* weights = ensemble->get_subtree_weights()->data.db; 2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* rcw0 = weights + n; 2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double lcw[2] = {0,0}, rcw[2]; 2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, best_i = -1; 2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val = 0; 2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int boost_type = ensemble->get_params().boost_type; 2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int split_criteria = ensemble->get_params().split_criteria; 2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rcw[0] = rcw0[0]; rcw[1] = rcw0[1]; 2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = n1; i < n; i++ ) 2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]; 2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rcw[responses[idx]] -= w; 2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split_criteria != 
CvBoost::GINI && split_criteria != CvBoost::MISCLASS ) 2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI; 2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split_criteria == CvBoost::GINI ) 2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double L = 0, R = rcw[0] + rcw[1]; 2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1]; 2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1 - 1; i++ ) 2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx], w2 = w*w; 2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double lv, rv; 2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = responses[idx]; 2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L += w; R -= w; 2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lv = lcw[idx]; rv = rcw[idx]; 2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lsum2 += 2*lv*w + w2; 2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum2 -= 2*rv*w - w2; 2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lcw[idx] = lv + w; rcw[idx] = rv - w; 2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sorted[i].val + epsilon < sorted[i+1].val ) 2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = (lsum2*R + rsum2*L)/(L*R); 2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 
2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; 2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1 - 1; i++ ) 3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]; 3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = responses[idx]; 3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lcw[idx] += w; 3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rcw[idx] -= w; 3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sorted[i].val + epsilon < sorted[i+1].val ) 3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0]; 3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = MAX(val, val2); 3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; 3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return best_i >= 0 ? data->new_split_ord( vi, 3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i, 3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 0, (float)best_val ) : 0; 3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b)) 3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int ) 3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* 3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi ) 3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split; 3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* cat_labels = data->get_cat_var_data(node, vi); 3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = data->get_var_type(vi); 3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int mi = data->cat_count->data.i[ci]; 3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double lcw[2]={0,0}, rcw[2]={0,0}; 3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* cjk = (double*)cvStackAlloc(2*(mi+1)*sizeof(cjk[0]))+2; 3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* weights = ensemble->get_subtree_weights()->data.db; 3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double** dbl_ptr = (double**)cvStackAlloc( mi*sizeof(dbl_ptr[0]) ); 
3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, j, k, idx; 3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double L = 0, R; 3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val = 0; 3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int best_subset = -1, subset_i; 3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int boost_type = ensemble->get_params().boost_type; 3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int split_criteria = ensemble->get_params().split_criteria; 3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // init array of counters: 3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // c_{jk} - number of samples that have vi-th input variable = j and response = k. 3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = -1; j < mi; j++ ) 3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cjk[j*2] = cjk[j*2+1] = 0; 3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[i]; 3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn j = cat_labels[i]; 3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn k = responses[i]; 3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cjk[j*2 + k] += w; 3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( j = 0; j < mi; j++ ) 3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rcw[0] += cjk[j*2]; 3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rcw[1] += cjk[j*2+1]; 3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dbl_ptr[j] = cjk + j*2 + 1; 3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
} 3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R = rcw[0] + rcw[1]; 3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS ) 3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI; 3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sort rows of c_jk by increasing c_j,1 3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // (i.e. by the weight of samples in j-th category that belong to class 1) 3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvSortDblPtr( dbl_ptr, mi, 0 ); 3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( subset_i = 0; subset_i < mi-1; subset_i++ ) 3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = (int)(dbl_ptr[subset_i] - cjk)/2; 3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* crow = cjk + idx*2; 3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w0 = crow[0], w1 = crow[1]; 3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double weight = w0 + w1; 3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weight < FLT_EPSILON ) 3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn continue; 3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lcw[0] += w0; rcw[0] -= w0; 3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lcw[1] += w1; rcw[1] -= w1; 3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split_criteria == CvBoost::GINI ) 
3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1]; 3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1]; 3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L += weight; 3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R -= weight; 3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( L > FLT_EPSILON && R > FLT_EPSILON ) 4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = (lsum2*R + rsum2*L)/(L*R); 4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_subset = subset_i; 4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = lcw[0] + rcw[1]; 4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val2 = lcw[1] + rcw[0]; 4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = MAX(val, val2); 4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_subset = subset_i; 4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_subset < 0 ) 4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 0; 4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = data->new_split_cat( vi, (float)best_val ); 4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i <= best_subset; i++ ) 4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn idx = (int)(dbl_ptr[i] - cjk) >> 1; 4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[idx >> 5] |= 1 << (idx & 31); 4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return split; 4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* 4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi ) 4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float epsilon = FLT_EPSILON*2; 4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* sorted = data->get_ord_var_data(node, vi); 4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* responses = data->get_ord_responses(node); 4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* weights = ensemble->get_subtree_weights()->data.db; 4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 
4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, best_i = -1; 4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val = 0, lsum = 0, rsum = node->value*n; 4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double L = 0, R = weights[n]; 4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compensate for missing values 4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = n1; i < n; i++ ) 4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]; 4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum -= responses[idx]*w; 4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R -= w; 4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // find the optimal split 4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1 - 1; i++ ) 4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]; 4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = responses[idx]*w; 4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn L += w; R -= w; 4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lsum += t; rsum -= t; 4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sorted[i].val + epsilon < sorted[i+1].val ) 4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = (lsum*lsum*R + rsum*rsum*L)/(L*R); 
4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; 4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return best_i >= 0 ? data->new_split_ord( vi, 4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i, 4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 0, (float)best_val ) : 0; 4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* 4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi ) 4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split; 4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* cat_labels = data->get_cat_var_data(node, vi); 4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* responses = data->get_ord_responses(node); 4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* weights = ensemble->get_subtree_weights()->data.db; 4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = data->get_var_type(vi); 4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int mi = data->cat_count->data.i[ci]; 4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 
1; 4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* counts = (double*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1; 4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) ); 5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double L = 0, R = 0, best_val = 0, lsum = 0, rsum = 0; 5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, best_subset = -1, subset_i; 5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = -1; i < mi; i++ ) 5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum[i] = counts[i] = 0; 5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // calculate sum response and weight of each category of the input var 5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = cat_labels[i]; 5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[i]; 5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s = sum[idx] + responses[i]*w; 5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double nc = counts[idx] + w; 5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum[idx] = s; 5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn counts[idx] = nc; 5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // calculate average response in each category 5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < mi; i++ ) 5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn R += counts[i]; 5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum += sum[i]; 
5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum[i] /= counts[i]; 5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum_ptr[i] = sum + i; 5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvSortDblPtr( sum_ptr, mi, 0 ); 5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // revert back to unnormalized sums 5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // (there should be a very little loss in accuracy) 5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < mi; i++ ) 5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum[i] *= counts[i]; 5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( subset_i = 0; subset_i < mi-1; subset_i++ ) 5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = (int)(sum_ptr[subset_i] - sum); 5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double ni = counts[idx]; 5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ni > FLT_EPSILON ) 5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s = sum[idx]; 5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lsum += s; L += ni; 5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rsum -= s; R -= ni; 5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( L > FLT_EPSILON && R > FLT_EPSILON ) 5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val = (lsum*lsum*R + rsum*rsum*L)/(L*R); 5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_val < val ) 
5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = val; 5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_subset = subset_i; 5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( best_subset < 0 ) 5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return 0; 5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split = data->new_split_cat( vi, (float)best_val ); 5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i <= best_subset; i++ ) 5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = (int)(sum_ptr[i] - sum); 5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[idx >> 5] |= 1 << (idx & 31); 5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return split; 5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* 5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi ) 5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float epsilon = FLT_EPSILON*2; 5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvPair32s32f* sorted = data->get_ord_var_data(node, vi); 5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const 
double* weights = ensemble->get_subtree_weights()->data.db; 5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* dir = (char*)data->direction->data.ptr; 5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n1 = node->get_num_valid(vi); 5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LL - number of samples that both the primary and the surrogate splits send to the left 5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LR - ... primary split sends to the left and the surrogate split sends to the right 5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // RL - ... primary split sends to the right and the surrogate split sends to the left 5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // RR - ... both send to the right 5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, best_i = -1, best_inversed = 0; 5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val; 5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double LL = 0, RL = 0, LR, RR; 5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double worst_val = node->maxlr; 5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0, sum_abs = 0; 5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = worst_val; 5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1; i++ ) 5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]; 5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[idx]; 5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += d*w; sum_abs += (d & 1)*w; 5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sum_abs = R + L; sum = R - L 
5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn RR = (sum_abs + sum)*0.5; 5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn LR = (sum_abs - sum)*0.5; 6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // initially all the samples are sent to the right by the surrogate split, 6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LR of them are sent to the left by primary split, and RR - to the right. 6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // now iteratively compute LL, LR, RL and RR for every possible surrogate split value. 6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n1 - 1; i++ ) 6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = sorted[i].i; 6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]; 6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[idx]; 6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( d < 0 ) 6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn LL += w; LR -= w; 6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val ) 6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = LL + RR; 6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; best_inversed = 0; 6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( d > 0 ) 6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn RL += w; RR -= w; 6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( RL + LR > best_val && sorted[i].val 
+ epsilon < sorted[i+1].val ) 6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val = RL + LR; 6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_i = i; best_inversed = 1; 6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi, 6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i, 6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_inversed, (float)best_val ) : 0; 6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* 6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi ) 6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* cat_labels = data->get_cat_var_data(node, vi); 6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* dir = (char*)data->direction->data.ptr; 6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* weights = ensemble->get_subtree_weights()->data.db; 6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int n = node->sample_count; 6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LL - number of samples that both the primary and the surrogate splits send to the left 6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LR - ... primary split sends to the left and the surrogate split sends to the right 6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // RL - ... 
primary split sends to the right and the surrogate split sends to the left 6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // RR - ... both send to the right 6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeSplit* split = data->new_split_cat( vi, 0 ); 6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, mi = data->cat_count->data.i[data->get_var_type(vi)]; 6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double best_val = 0; 6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1; 6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* rc = lc + mi + 1; 6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = -1; i < mi; i++ ) 6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[i] = rc[i] = 0; 6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 1. for each category calculate the weight of samples 6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sent to the left (lc) and to the right (rc) by the primary split 6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < n; i++ ) 6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = cat_labels[i]; 6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[i]; 6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int d = dir[i]; 6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = lc[idx] + d*w; 6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum_abs = rc[idx] + (d & 1)*w; 6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[idx] = sum; rc[idx] = sum_abs; 6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < mi; i++ ) 
6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = lc[i]; 6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum_abs = rc[i]; 6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn lc[i] = (sum_abs - sum) * 0.5; 6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rc[i] = (sum_abs + sum) * 0.5; 6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // 2. now form the split. 6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // in each category send all the samples to the same direction as majority 6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < mi; i++ ) 6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double lval = lc[i], rval = rc[i]; 6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( lval > rval ) 6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->subset[i >> 5] |= 1 << (i & 31); 6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val += lval; 6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn best_val += rval; 6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn split->quality = (float)best_val; 6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split->quality <= node->maxlr ) 6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetRemoveByPtr( data->split_heap, split ), split = 0; 6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return split; 6956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn} 6966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 6986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid 6996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoostTree::calc_node_value( CvDTreeNode* node ) 7006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 7016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, count = node->sample_count; 7026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double* weights = ensemble->get_weights()->data.db; 7036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* labels = data->get_labels(node); 7046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double* subtree_weights = ensemble->get_subtree_weights()->data.db; 7056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double rcw[2] = {0,0}; 7066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int boost_type = ensemble->get_params().boost_type; 7076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn //const double* priors = data->priors->data.db; 7086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->is_classifier ) 7106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* responses = data->get_class_labels(node); 7126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 7146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = labels[i]; 7166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]/*priors[responses[i]]*/; 7176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rcw[responses[i]] += w; 7186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subtree_weights[i] = w; 7196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
7216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->class_idx = rcw[1] > rcw[0]; 7226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( boost_type == CvBoost::DISCRETE ) 7246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // ignore cat_map for responses, and use {-1,1}, 7266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // as the whole ensemble response is computes as sign(sum_i(weak_response_i) 7276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->value = node->class_idx*2 - 1; 7286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 7306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double p = rcw[1]/(rcw[0] + rcw[1]); 7326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( boost_type == CvBoost::REAL ); 7336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // store log-ratio of the probability 7356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->value = 0.5*log_ratio(p); 7366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 7396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // in case of regression tree: 7416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response, 7426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // n is the number of samples in the node. 
7436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2) 7446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0, sum2 = 0, iw; 7456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* values = data->get_ord_responses(node); 7466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 7486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 7496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = labels[i]; 7506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights[idx]/*priors[values[i] > 0]*/; 7516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double t = values[i]; 7526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn rcw[0] += w; 7536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subtree_weights[i] = w; 7546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += t*w; 7556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum2 += t*t*w; 7566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn iw = 1./rcw[0]; 7596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->value = sum*iw; 7606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->node_risk = sum2 - (sum*iw)*sum; 7616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // renormalize the risk, as in try_split_node the unweighted formula 7636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i) 7646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn node->node_risk *= count*iw*count*iw; 7656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 7666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // store summary weights 
7686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subtree_weights[count] = rcw[0]; 7696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subtree_weights[count+1] = rcw[1]; 7706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 7716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data ) 7746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 7756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTree::read( fs, fnode, _data ); 7766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ensemble = _ensemble; 7776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 7786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoostTree::read( CvFileStorage*, CvFileNode* ) 7816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 7826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert(0); 7836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 7846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node, 7866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTreeTrainData* _data ) 7876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 7886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvDTree::read( _fs, _node, _data ); 7896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 7906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/////////////////////////////////// CvBoost ///////////////////////////////////// 7936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 7946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
RennCvBoost::CvBoost() 7956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 7966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = 0; 7976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak = 0; 7986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default_model_name = "my_boost_tree"; 7996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn orig_response = sum_response = weak_eval = subsample_mask = 8006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights = subtree_weights = 0; 8016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 8036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoost::prune( CvSlice slice ) 8076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weak ) 8096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 8116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, count = cvSliceLength( slice, weak ); 8126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartReadSeq( weak, &reader ); 8146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetSeqReaderPos( &reader, slice.start_index ); 8156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 8176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvBoostTree* w; 8196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_READ_SEQ_ELEM( w, reader ); 8206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn delete w; 8216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
8226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSeqRemoveSlice( weak, slice ); 8246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoost::clear() 8296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weak ) 8316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 8326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn prune( CV_WHOLE_SEQ ); 8336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMemStorage( &weak->storage ); 8346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 8356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data ) 8366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn delete data; 8376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak = 0; 8386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = 0; 8396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &orig_response ); 8406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &sum_response ); 8416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &weak_eval ); 8426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &subsample_mask ); 8436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &weights ); 8446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn have_subsample = false; 8456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::~CvBoost() 8496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 
8516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::CvBoost( const CvMat* _train_data, int _tflag, 8556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _responses, const CvMat* _var_idx, 8566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _sample_idx, const CvMat* _var_type, 8576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _missing_mask, CvBoostParams _params ) 8586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak = 0; 8606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = 0; 8616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn default_model_name = "my_boost_tree"; 8626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn orig_response = sum_response = weak_eval = subsample_mask = weights = 0; 8636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn train( _train_data, _tflag, _responses, _var_idx, _sample_idx, 8656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _var_type, _missing_mask, _params ); 8666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 8676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool 8706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::set_params( const CvBoostParams& _params ) 8716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 8726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool ok = false; 8736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvBoost::set_params" ); 8756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 
8776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params = _params; 8796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.boost_type != DISCRETE && params.boost_type != REAL && 8806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.boost_type != LOGIT && params.boost_type != GENTLE ) 8816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" ); 8826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.weak_count = MAX( params.weak_count, 1 ); 8846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.weight_trim_rate = MAX( params.weight_trim_rate, 0. ); 8856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.weight_trim_rate = MIN( params.weight_trim_rate, 1. ); 8866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.weight_trim_rate < FLT_EPSILON ) 8876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.weight_trim_rate = 1.f; 8886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.boost_type == DISCRETE && 8906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.split_criteria != GINI && params.split_criteria != MISCLASS ) 8916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.split_criteria = MISCLASS; 8926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.boost_type == REAL && 8936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.split_criteria != GINI && params.split_criteria != MISCLASS ) 8946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.split_criteria = GINI; 8956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( (params.boost_type == LOGIT || params.boost_type == GENTLE) && 8966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.split_criteria != SQERR ) 8976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
params.split_criteria = SQERR; 8986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 8996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ok = true; 9006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 9026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return ok; 9046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool 9086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::train( const CvMat* _train_data, int _tflag, 9096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _responses, const CvMat* _var_idx, 9106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _sample_idx, const CvMat* _var_type, 9116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const CvMat* _missing_mask, 9126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvBoostParams _params, bool _update ) 9136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool ok = false; 9156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMemStorage* storage = 0; 9166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvBoost::train" ); 9186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 9206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i; 9226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn set_params( _params ); 9246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !_update || !data ) 
9266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 9286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx, 9296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample_idx, _var_type, _missing_mask, _params, true, true ); 9306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->get_num_classes() != 2 ) 9326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsNotImplemented, 9336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "Boosted trees can only be used for 2-class classification." ); 9346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( storage = cvCreateMemStorage() ); 9356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage ); 9366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn storage = 0; 9376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 9396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->set_data( _train_data, _tflag, _responses, _var_idx, 9416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample_idx, _var_type, _missing_mask, _params, true, true, true ); 9426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn update_weights( 0 ); 9456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < params.weak_count; i++ ) 9476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvBoostTree* tree = new CvBoostTree; 9496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !tree->train( data, 
subsample_mask, this ) ) 9506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn delete tree; 9526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn continue; 9536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn //cvCheckArr( get_weak_response()); 9556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSeqPush( weak, &tree ); 9566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn update_weights( tree ); 9576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn trim_weights(); 9586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 9596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->is_classifier = true; 9616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ok = true; 9626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 9646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return ok; 9666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 9676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid 9706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::update_weights( CvBoostTree* tree ) 9716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 9726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvBoost::update_weights" ); 9736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 9756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, count = data->sample_count; 9776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sumw = 0.; 9786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
9796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !tree ) // before training the first tree, initialize weights and other parameters 9806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 9816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* class_labels = data->get_class_labels(data->data_root); 9826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // in case of logitboost and gentle adaboost each weak tree is a regression tree, 9836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // so we need to convert class labels to floating-point values 9846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* responses = data->get_ord_responses(data->data_root); 9856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int* labels = data->get_labels(data->data_root); 9866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w0 = 1./count; 9876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double p[2] = { 1, 1 }; 9886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &orig_response ); 9906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &sum_response ); 9916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &weak_eval ); 9926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &subsample_mask ); 9936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvReleaseMat( &weights ); 9946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 9956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( orig_response = cvCreateMat( 1, count, CV_32S )); 9966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( weak_eval = cvCreateMat( 1, count, CV_64F )); 9976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( subsample_mask = cvCreateMat( 1, count, CV_8U )); 9986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( weights = cvCreateMat( 1, count, CV_64F )); 9996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( 
subtree_weights = cvCreateMat( 1, count + 2, CV_64F )); 10006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( data->have_priors ) 10026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // compute weight scale for each class from their prior probabilities 10046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int c1 = 0; 10056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 10066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c1 += class_labels[i]; 10076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn p[0] = data->priors->data.db[0]*(c1 < count ? 1./(count - c1) : 0.); 10086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.); 10096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn p[0] /= p[0] + p[1]; 10106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn p[1] = 1. - p[0]; 10116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 10146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // save original categorical responses {0,1}, convert them to {-1,1} 10166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn orig_response->data.i[i] = class_labels[i]*2 - 1; 10176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // make all the samples active at start. 10186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // later, in trim_weights() deactivate/reactive again some, if need 10196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subsample_mask->data.ptr[i] = (uchar)1; 10206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // make all the initial weights the same. 
10216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights->data.db[i] = w0*p[class_labels[i]]; 10226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // set the labels to find (from within weak tree learning proc) 10236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // the particular sample weight, and where to store the response. 10246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn labels[i] = i; 10256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.boost_type == LOGIT ) 10286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( sum_response = cvCreateMat( 1, count, CV_64F )); 10306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 10326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum_response->data.db[i] = 0; 10346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn responses[i] = orig_response->data.i[i] > 0 ? 2.f : -2.f; 10356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // in case of logitboost each weak tree is a regression tree. 
10386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // the target function values are recalculated for each of the trees 10396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->is_classifier = false; 10406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( params.boost_type == GENTLE ) 10426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 10446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn responses[i] = (float)orig_response->data.i[i]; 10456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->is_classifier = false; 10476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 10506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // at this moment, for all the samples that participated in the training of the most 10526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // recent weak classifier we know the responses. 
For other samples we need to compute them 10536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( have_subsample ) 10546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* values = (float*)(data->buf->data.ptr + data->buf->step); 10566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn uchar* missing = data->buf->data.ptr + data->buf->step*2; 10576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat _sample, _mask; 10586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // invert the subsample mask 10606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvXorS( subsample_mask, cvScalar(1.), subsample_mask ); 10616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->get_vectors( subsample_mask, values, missing, 0 ); 10626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn //data->get_vectors( 0, values, missing, 0 ); 10636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample = cvMat( 1, data->var_count, CV_32F ); 10656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _mask = cvMat( 1, data->var_count, CV_8U ); 10666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // run tree through all the non-processed samples 10686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 10696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( subsample_mask->data.ptr[i] ) 10706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample.data.fl = values; 10726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _mask.data.ptr = missing; 10736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn values += _sample.cols; 10746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn missing += _mask.cols; 10756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value; 10766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // now update weights and other parameters for each type of boosting 10806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.boost_type == DISCRETE ) 10816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // Discrete AdaBoost: 10836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // weak_eval[i] (=f(x_i)) is in {-1,1} 10846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // err = sum(w_i*(f(x_i) != y_i))/sum(w_i) 10856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // C = log((1-err)/err) 10866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // w_i *= exp(C*(f(x_i) != y_i)) 10876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double C, err = 0.; 10896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double scale[] = { 1., 0. 
}; 10906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 10926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 10936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights->data.db[i]; 10946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sumw += w; 10956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn err += w*(weak_eval->data.db[i] != orig_response->data.i[i]); 10966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 10976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 10986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sumw != 0 ) 10996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn err /= sumw; 11006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn C = err = -log_ratio( err ); 11016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[1] = exp(err); 11026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sumw = 0; 11046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 11056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights->data.db[i]* 11076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn scale[weak_eval->data.db[i] != orig_response->data.i[i]]; 11086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sumw += w; 11096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights->data.db[i] = w; 11106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tree->scale( C ); 11136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( params.boost_type == REAL ) 11156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // Real AdaBoost: 
11176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i) 11186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // w_i *= exp(-y_i*f(x_i)) 11196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 11216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_eval->data.db[i] *= -orig_response->data.i[i]; 11226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvExp( weak_eval, weak_eval ); 11246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 11266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights->data.db[i]*weak_eval->data.db[i]; 11286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sumw += w; 11296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights->data.db[i] = w; 11306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( params.boost_type == LOGIT ) 11336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // LogitBoost: 11356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // weak_eval[i] = f(x_i) in [-z_max,z_max] 11366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // sum_response = F(x_i). 
11376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // F(x_i) += 0.5*f(x_i) 11386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i))) 11396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // reuse weak_eval: weak_eval[i] <- p(x_i) 11406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // w_i = p(x_i)*1(1 - p(x_i)) 11416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i))) 11426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // store z_i to the data->data_root as the new target responses 11436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double lb_weight_thresh = FLT_EPSILON; 11456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const double lb_z_max = 10.; 11466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* responses = data->get_ord_responses(data->data_root); 11476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn /*if( weak->total == 7 ) 11496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn putchar('*');*/ 11506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 11526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i]; 11546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum_response->data.db[i] = s; 11556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_eval->data.db[i] = -2*s; 11566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvExp( weak_eval, weak_eval ); 11596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( 
i = 0; i < count; i++ ) 11616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double p = 1./(1. + weak_eval->data.db[i]); 11636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = p*(1 - p), z; 11646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn w = MAX( w, lb_weight_thresh ); 11656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights->data.db[i] = w; 11666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sumw += w; 11676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( orig_response->data.i[i] > 0 ) 11686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn z = 1./p; 11706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn responses[i] = (float)MIN(z, lb_z_max); 11716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 11736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn z = 1./(1-p); 11756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn responses[i] = (float)-MIN(z, lb_z_max); 11766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 11806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // Gentle AdaBoost: 11826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // weak_eval[i] = f(x_i) in [-1,1] 11836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // w_i *= exp(-y_i*f(x_i)) 11846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn assert( params.boost_type == GENTLE ); 11856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 11876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 
weak_eval->data.db[i] *= -orig_response->data.i[i]; 11886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvExp( weak_eval, weak_eval ); 11906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 11916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 11926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 11936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights->data.db[i] * weak_eval->data.db[i]; 11946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights->data.db[i] = w; 11956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sumw += w; 11966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 11996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // renormalize weights 12016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sumw > FLT_EPSILON ) 12026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sumw = 1./sumw; 12046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; ++i ) 12056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weights->data.db[i] *= sumw; 12066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 12096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int ) 12136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid 
12166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::trim_weights() 12176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvBoost::trim_weights" ); 12196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 12216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, count = data->sample_count, nz_count = 0; 12236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum, threshold; 12246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. ) 12266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn EXIT; 12276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // use weak_eval as temporary buffer for sorted weights 12296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvCopy( weights, weak_eval ); 12306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn icvSort_64f( weak_eval->data.db, count, 0 ); 12326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // as weight trimming occurs immediately after updating the weights, 12346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // where they are renormalized, we assume that the weight sum = 1. 12356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum = 1. 
- params.weight_trim_rate; 12366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 12386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weak_eval->data.db[i]; 12406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( sum > w ) 12416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 12426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum -= w; 12436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn threshold = i < count ? weak_eval->data.db[i] : DBL_MAX; 12466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < count; i++ ) 12486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double w = weights->data.db[i]; 12506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int f = w > threshold; 12516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn subsample_mask->data.ptr[i] = (uchar)f; 12526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn nz_count += f; 12536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 12546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn have_subsample = nz_count < count; 12566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 12586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 12596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat 12626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::predict( const CvMat* _sample, const CvMat* _missing, 12636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn CvMat* weak_responses, CvSlice slice, 12646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool raw_mode ) const 12656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 12666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* buf = 0; 12676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool allocated = false; 12686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float value = -FLT_MAX; 12696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvBoost::predict" ); 12716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 12736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, weak_count, var_count; 12756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMat sample, missing; 12766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 12776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double sum = 0; 12786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int cls_idx; 12796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int wstep = 0; 12806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* vtype; 12816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* cmap; 12826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* cofs; 12836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !weak ) 12856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" ); 12866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 || 12886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample->cols != 1 && _sample->rows != 1 || 12896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius 
Renn _sample->cols + _sample->rows - 1 != data->var_all && !raw_mode || 12906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample->cols + _sample->rows - 1 != data->var_count && raw_mode ) 12916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 12926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the input sample must be 1d floating-point vector with the same " 12936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "number of elements as the total number of variables used for training" ); 12946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 12956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _missing ) 12966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 12976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) || 12986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn !CV_ARE_SIZES_EQ(_missing, _sample) ) 12996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 13006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "the missing data mask must be 8-bit vector of the same size as input sample" ); 13016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_count = cvSliceLength( slice, weak ); 13046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weak_count >= weak->total ) 13056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_count = weak->total; 13076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn slice.start_index = 0; 13086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weak_responses ) 13116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !CV_IS_MAT(weak_responses) || 
13136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_MAT_TYPE(weak_responses->type) != CV_32FC1 || 13146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_responses->cols != 1 && weak_responses->rows != 1 || 13156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_responses->cols + weak_responses->rows - 1 != weak_count ) 13166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 13176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "The output matrix of weak classifier responses must be valid " 13186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "floating-point vector of the same number of components as the length of input slice" ); 13196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float); 13206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn var_count = data->var_count; 13236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn vtype = data->var_type->data.i; 13246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cmap = data->cat_map->data.i; 13256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cofs = data->cat_ofs->data.i; 13266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn // if need, preprocess the input vector 13286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !raw_mode && (data->cat_var_count > 0 || data->var_idx) ) 13296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int bufsize; 13316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int step, mstep = 0; 13326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const float* src_sample; 13336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const uchar* src_mask = 0; 13346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float* dst_sample; 
13356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn uchar* dst_mask; 13366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const int* vidx = data->var_idx && !raw_mode ? data->var_idx->data.i : 0; 13376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bool have_mask = _missing != 0; 13386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn bufsize = var_count*(sizeof(float) + sizeof(uchar)); 13406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( bufsize <= CV_MAX_LOCAL_SIZE ) 13416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn buf = (float*)cvStackAlloc( bufsize ); 13426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 13436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( buf = (float*)cvAlloc( bufsize )); 13456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn allocated = true; 13466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst_sample = buf; 13486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst_mask = (uchar*)(buf + var_count); 13496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn src_sample = _sample->data.fl; 13516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]); 13526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( _missing ) 13546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn src_mask = _missing->data.ptr; 13566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn mstep = CV_IS_MAT_CONT(_missing->type) ? 
1 : _missing->step; 13576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < var_count; i++ ) 13606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int idx = vidx ? vidx[i] : i; 13626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn float val = src_sample[idx*step]; 13636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ci = vtype[i]; 13646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn uchar m = src_mask ? src_mask[i] : (uchar)0; 13656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ci >= 0 ) 13676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int a = cofs[ci], b = cofs[ci+1], c = a; 13696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int ival = cvRound(val); 13706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ival != val ) 13716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, 13726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "one of input categorical variable is not an integer" ); 13736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn while( a < b ) 13756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn c = (a + b) >> 1; 13776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ival < cmap[c] ) 13786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn b = c; 13796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else if( ival > cmap[c] ) 13806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn a = c+1; 13816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 13826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn break; 13836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 
13846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( c < 0 || ival != cmap[c] ) 13866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn m = 1; 13886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn have_mask = true; 13896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 13916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 13926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = (float)(c - cofs[ci]); 13936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 13966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst_sample[i] = val; 13976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn dst_mask[i] = m; 13986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 13996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sample = cvMat( 1, var_count, CV_32F, dst_sample ); 14016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _sample = &sample; 14026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( have_mask ) 14046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn missing = cvMat( 1, var_count, CV_8UC1, dst_mask ); 14066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn _missing = &missing; 14076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartReadSeq( weak, &reader ); 14116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSetSeqReaderPos( &reader, slice.start_index ); 
14126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < weak_count; i++ ) 14146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 14156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvBoostTree* wtree; 14166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn double val; 14176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_READ_SEQ_ELEM( wtree, reader ); 14196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn val = wtree->predict( _sample, _missing, true )->value; 14216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( weak_responses ) 14226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak_responses->data.fl[i*wstep] = (float)val; 14236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn sum += val; 14256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 14266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cls_idx = sum >= 0; 14286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( raw_mode ) 14296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn value = (float)cls_idx; 14306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 14316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn value = (float)cmap[cofs[vtype[var_count]] + cls_idx]; 14326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 14346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( allocated ) 14366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvFree( &buf ); 14376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return value; 14396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 
14406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoost::write_params( CvFileStorage* fs ) 14446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 14456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvBoost::write_params" ); 14466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 14486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* boost_type_str = 14506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.boost_type == DISCRETE ? "DiscreteAdaboost" : 14516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.boost_type == REAL ? "RealAdaboost" : 14526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.boost_type == LOGIT ? "LogitBoost" : 14536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.boost_type == GENTLE ? "GentleAdaboost" : 0; 14546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* split_crit_str = 14566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.split_criteria == DEFAULT ? "Default" : 14576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.split_criteria == GINI ? "Gini" : 14586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.boost_type == MISCLASS ? "Misclassification" : 14596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.boost_type == SQERR ? 
"SquaredErr" : 0; 14606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( boost_type_str ) 14626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteString( fs, "boosting_type", boost_type_str ); 14636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 14646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "boosting_type", params.boost_type ); 14656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( split_crit_str ) 14676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteString( fs, "splitting_criteria", split_crit_str ); 14686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 14696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "splitting_criteria", params.split_criteria ); 14706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteInt( fs, "ntrees", params.weak_count ); 14726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate ); 14736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->write_params( fs ); 14756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 14776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 14786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode ) 14816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 14826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvBoost::read_params" ); 14836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 
14856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* temp; 14876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !fnode || !CV_NODE_IS_MAP(fnode->tag) ) 14896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 14906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data = new CvDTreeTrainData(); 14926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( data->read_params(fs, fnode)); 14936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn data->shared = true; 14946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 14956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.max_depth = data->params.max_depth; 14966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.min_sample_count = data->params.min_sample_count; 14976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.max_categories = data->params.max_categories; 14986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.priors = data->params.priors; 14996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.regression_accuracy = data->params.regression_accuracy; 15006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.use_surrogates = data->params.use_surrogates; 15016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn temp = cvGetFileNodeByName( fs, fnode, "boosting_type" ); 15036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !temp ) 15046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return; 15056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( temp && CV_NODE_IS_STRING(temp->tag) ) 15076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* boost_type_str = cvReadString( temp, "" ); 
15096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE : 15106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL : 15116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT : 15126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1; 15136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 15156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.boost_type = cvReadInt( temp, -1 ); 15166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.boost_type < DISCRETE || params.boost_type > GENTLE ) 15186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Unknown boosting type" ); 15196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" ); 15216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( temp && CV_NODE_IS_STRING(temp->tag) ) 15226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn const char* split_crit_str = cvReadString( temp, "" ); 15246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT : 15256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn strcmp( split_crit_str, "Gini" ) == 0 ? GINI : 15266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS : 15276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn strcmp( split_crit_str, "SquaredErr" ) == 0 ? 
SQERR : -1; 15286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn else 15306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.split_criteria = cvReadInt( temp, -1 ); 15316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( params.split_criteria < DEFAULT || params.boost_type > SQERR ) 15336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "Unknown boosting type" ); 15346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.weak_count = cvReadIntByName( fs, fnode, "ntrees" ); 15366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. ); 15376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 15396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 15406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid 15446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::read( CvFileStorage* fs, CvFileNode* node ) 15456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 15466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvRTrees::read" ); 15476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 15496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 15516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvFileNode* trees_fnode; 15526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvMemStorage* storage; 15536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i, ntrees; 
15546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn clear(); 15566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn read_params( fs, node ); 15576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !data ) 15596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn EXIT; 15606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn trees_fnode = cvGetFileNodeByName( fs, node, "trees" ); 15626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) ) 15636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsParseError, "<trees> tag is missing" ); 15646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartReadSeq( trees_fnode->data.seq, &reader ); 15666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn ntrees = trees_fnode->data.seq->total; 15676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( ntrees != params.weak_count ) 15696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsUnmatchedSizes, 15706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn "The number of trees stored does not match <ntrees> tag value" ); 15716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL( storage = cvCreateMemStorage() ); 15736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage ); 15746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < ntrees; i++ ) 15766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 15776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvBoostTree* tree = new CvBoostTree(); 
15786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data )); 15796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader ); 15806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvSeqPush( weak, &tree ); 15816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 15826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 15846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 15856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid 15886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::write( CvFileStorage* fs, const char* name ) 15896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 15906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_FUNCNAME( "CvBoost::write" ); 15916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __BEGIN__; 15936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvSeqReader reader; 15956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn int i; 15966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING ); 15986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 15996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn if( !weak ) 16006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" ); 16016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn write_params( fs ); 16036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, "trees", CV_NODE_SEQ ); 
16046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartReadSeq( weak, &reader ); 16066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn for( i = 0; i < weak->total; i++ ) 16086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn { 16096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CvBoostTree* tree; 16106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn CV_READ_SEQ_ELEM( tree, reader ); 16116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvStartWriteStruct( fs, 0, CV_NODE_MAP ); 16126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn tree->write( fs ); 16136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 16146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn } 16156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 16176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn cvEndWriteStruct( fs ); 16186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn __END__; 16206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 16216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvMat* 16246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::get_weights() 16256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 16266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return weights; 16276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 16286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvMat* 16316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::get_subtree_weights() 16326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 
16336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return subtree_weights; 16346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 16356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvMat* 16386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::get_weak_response() 16396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 16406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return weak_eval; 16416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 16426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvBoostParams& 16456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvBoost::get_params() const 16466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 16476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return params; 16486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 16496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvSeq* CvBoost::get_weak_predictors() 16516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{ 16526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn return weak; 16536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn} 16546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn 16556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* End of file. */ 1656