16acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*M///////////////////////////////////////////////////////////////////////////////////////
26acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
36acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
46acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
56acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  By downloading, copying, installing or using the software you agree to this license.
66acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  If you do not agree to this license, do not download, install,
76acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  copy or use the software.
86acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
96acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//                        Intel License Agreement
116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Copyright (C) 2000, Intel Corporation, all rights reserved.
136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Third party copyrights are property of their respective owners.
146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Redistribution and use in source and binary forms, with or without modification,
166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are permitted provided that the following conditions are met:
176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's of source code must retain the above copyright notice,
196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer.
206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's in binary form must reproduce the above copyright notice,
226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer in the documentation
236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     and/or other materials provided with the distribution.
246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * The name of Intel Corporation may not be used to endorse or promote products
266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     derived from this software without specific prior written permission.
276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// This software is provided by the copyright holders and contributors "as is" and
296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// any express or implied warranties, including, but not limited to, the implied
306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// warranties of merchantability and fitness for a particular purpose are disclaimed.
316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct,
326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages
336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// (including, but not limited to, procurement of substitute goods or services;
346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused
356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability,
366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort (including negligence or otherwise) arising in any way out of
376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even if advised of the possibility of such damage.
386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/
406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h"
426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
// Half of FLT_MAX. Presumably the stand-in for a missing ordered (numerical)
// value -- the name suggests this; its uses are outside this part of the file.
static const float ord_nan = 0.5f*FLT_MAX;

// Lower bound for the block size handed to cvCreateMemStorage() when the
// tree and temporary storages are created (65536).
static const int min_block_size = 64*1024;

// Slack added to a computed block size before it is compared against
// min_block_size (1024).
static const int block_size_delta = 1024;
466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeTrainData::CvDTreeTrainData()
486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_idx = var_type = cat_count = cat_ofs = cat_map =
506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        priors = priors_mult = counts = buf = direction = split_buf = 0;
516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    tree_storage = temp_storage = 0;
526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      const CvMat* _responses, const CvMat* _var_idx,
596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      const CvMat* _sample_idx, const CvMat* _var_type,
606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      const CvMat* _missing_mask, const CvDTreeParams& _params,
616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                      bool _shared, bool _add_labels )
626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_idx = var_type = cat_count = cat_ofs = cat_map =
646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        priors = priors_mult = counts = buf = direction = split_buf = 0;
656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    tree_storage = temp_storage = 0;
666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn              _var_type, _missing_mask, _params, _shared, _add_labels );
696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeTrainData::~CvDTreeTrainData()
736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool ok = false;
816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTreeTrainData::set_params" );
836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // set parameters
876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params = _params;
886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.max_categories < 2 )
906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.max_categories = MIN( params.max_categories, 15 );
926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.max_depth < 0 )
946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.max_depth = MIN( params.max_depth, 25 );
966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.min_sample_count = MAX(params.min_sample_count,1);
986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.cv_folds < 0 )
1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsOutOfRange,
1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "params.cv_folds should be =0 (the tree is not pruned) "
1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "or n>0 (tree is pruned using n-fold cross-validation)" );
1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.cv_folds == 1 )
1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.cv_folds = 0;
1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.regression_accuracy < 0 )
1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ok = true;
1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return ok;
1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
// Ordering predicate used by icvSortIntPtr / icvSortDblPtr: true when the value
// pointed to by `a` is less than the value pointed to by `b`, so arrays of
// pointers are ordered by the values they point at (presumably e.g. the int**
// int_ptr array filled in set_data() -- its sort call is outside this view).
// The trailing `int` argument of CV_IMPLEMENT_QSORT_EX is the macro's auxiliary
// parameter type; this predicate does not use it.
1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
// Ordering predicate for icvSortPairs: compares the `val` fields of two
// CvPair32s32f elements, so the pairs are ordered by `val`.
1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennstatic CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool _shared, bool _add_labels, bool _update_data )
1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* sample_idx = 0;
1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* var_type0 = 0;
1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* tmp_map = 0;
1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int** int_ptr = 0;
1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeTrainData* data = 0;
1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTreeTrainData::set_data" );
1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int sample_all = 0, r_type = 0, cv_n;
1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int total_c_count = 0;
1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int vi, i;
1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    char err[100];
1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int *sidx = 0, *vidx = 0;
1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _update_data && data_root )
1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // compare new and old train data
1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !(data->var_count == var_count &&
1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg,
1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "The new training data must have the same types and the input and output variables "
1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "and the same categories for categorical variables" );
1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &priors );
1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &priors_mult );
1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &buf );
1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &direction );
1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &split_buf );
1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMemStorage( &temp_storage );
1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        priors = data->priors; data->priors = 0;
1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        priors_mult = data->priors_mult; data->priors_mult = 0;
1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        buf = data->buf; data->buf = 0;
1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        buf_count = data->buf_count; buf_size = data->buf_size;
1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sample_count = data->sample_count;
1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        direction = data->direction; data->direction = 0;
1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split_buf = data->split_buf; data->split_buf = 0;
1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        temp_storage = data->temp_storage; data->temp_storage = 0;
1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        nv_heap = data->nv_heap; cv_heap = data->cv_heap;
1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        data_root = new_node( 0, sample_count, 0, 0 );
1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        EXIT;
1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_all = 0;
1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rng = cvRNG(-1);
1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( set_params( _params ));
1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // check parameter types and sizes
1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _tflag == CV_ROW_SAMPLE )
1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dv_step = 1;
1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( _missing_mask )
1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ms_step = _missing_mask->step, mv_step = 1;
1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ds_step = 1;
2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( _missing_mask )
2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            mv_step = _missing_mask->step, ms_step = 1;
2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    sample_count = sample_all;
2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_count = var_all;
2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _sample_idx )
2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sidx = sample_idx->data.i;
2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sample_count = sample_idx->rows + sample_idx->cols - 1;
2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _var_idx )
2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        vidx = var_idx->data.i;
2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        var_count = var_idx->rows + var_idx->cols - 1;
2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_MAT(_responses) ||
2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn         CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _responses->rows != 1 && _responses->cols != 1 ||
2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _responses->rows + _responses->cols - 1 != sample_all )
2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                  "floating-point vector containing as many elements as "
2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                  "the total number of samples in the training data matrix" );
2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cat_var_count = 0;
2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ord_var_count = -1;
2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    is_classifier = r_type == CV_VAR_CATEGORICAL;
2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // step 0. calc the number of categorical vars
2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( vi = 0; vi < var_count; vi++ )
2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cat_var_count++ : ord_var_count--;
2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ord_var_count = ~ord_var_count;
2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cv_n = params.cv_folds;
2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // set the two last elements of var_type array to be able
2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // to locate responses and cross-validation labels using
2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // the corresponding get_* functions.
2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_type->data.i[var_count] = cat_var_count;
2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_type->data.i[var_count+1] = cat_var_count+1;
2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // in case of single ordered predictor we need dummy cv_labels
2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // for safe split_node_data() operation
2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    have_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0 || _add_labels;
2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_size = (ord_var_count + get_work_var_count())*sample_count + 2;
2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    shared = _shared;
2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_count = shared ? 3 : 2;
2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // now calculate the maximum size of split,
2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // create memory storage that will keep nodes and splits of the decision tree
2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // allocate root node and the buffer for the whole training data
2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nv_size = var_count*sizeof(int);
2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nv_size = MAX( nv_size, (int)sizeof(CvSetElem) );
2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    temp_block_size = nv_size;
2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cv_n )
2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( sample_count < cv_n*MAX(params.min_sample_count,10) )
2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsOutOfRange,
2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                "The many folds in cross-validation for such a small dataset" );
2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        temp_block_size = MAX(temp_block_size, cv_size);
2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cv_size )
2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_c_count = 1;
3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // transform the training data to convenient representation
3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( vi = 0; vi <= var_count; vi++ )
3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int ci;
3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const uchar* mask = 0;
3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int m_step = 0, step;
3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* idata = 0;
3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const float* fdata = 0;
3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int num_valid = 0;
3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( vi < var_count ) // analyze i-th input variable
3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int vi0 = vidx ? vidx[vi] : vi;
3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ci = get_var_type(vi);
3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            step = ds_step; m_step = ms_step;
3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                idata = _train_data->data.i + vi0*dv_step;
3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                fdata = _train_data->data.fl + vi0*dv_step;
3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( _missing_mask )
3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                mask = _missing_mask->data.ptr + vi0*mv_step;
3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else // analyze _responses
3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ci = cat_var_count;
3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            step = CV_IS_MAT_CONT(_responses->type) ?
3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                1 : _responses->step / CV_ELEM_SIZE(_responses->type);
3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                idata = _responses->data.i;
3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                fdata = _responses->data.fl;
3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( vi < var_count && ci >= 0 ||
3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            vi == var_count && is_classifier ) // process categorical variable or response
3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int c_count, prev_label;
3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int* c_map, *dst = get_cat_var_data( data_root, vi );
3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // copy data
3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < sample_count; i++ )
3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int val = INT_MAX, si = sidx ? sidx[i] : i;
3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( !mask || !mask[si*m_step] )
3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( idata )
3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        val = idata[si*step];
3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    else
3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        float t = fdata[si*step];
3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        val = cvRound(t);
3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        if( val != t )
3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        {
3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            sprintf( err, "%d-th value of %d-th (categorical) "
3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                "variable is not an integer", i, vi );
3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            CV_ERROR( CV_StsBadArg, err );
3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        }
3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( val == INT_MAX )
3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        sprintf( err, "%d-th value of %d-th (categorical) "
3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            "variable is too large", i, vi );
3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        CV_ERROR( CV_StsBadArg, err );
3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    num_valid++;
3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[i] = val;
3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int_ptr[i] = dst + i;
3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // sort all the values, including the missing measurements
3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // that should all move to the end
3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            icvSortIntPtr( int_ptr, sample_count, 0 );
3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            c_count = num_valid > 0;
3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // count the categories
3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 1; i < num_valid; i++ )
3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                c_count += *int_ptr[i] != *int_ptr[i-1];
3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( vi > 0 )
3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                max_c_count = MAX( max_c_count, c_count );
3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cat_count->data.i[ci] = c_count;
3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cat_ofs->data.i[ci] = total_c_count;
3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // resize cat_map, if need
3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( cat_map->cols < total_c_count + c_count )
3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                tmp_map = cat_map;
3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_CALL( cat_map = cvCreateMat( 1,
3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < total_c_count; i++ )
3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cat_map->data.i[i] = tmp_map->data.i[i];
4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvReleaseMat( &tmp_map );
4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            c_map = cat_map->data.i + total_c_count;
4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            total_c_count += c_count;
4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // compact the class indices and build the map
4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            prev_label = ~*int_ptr[0];
4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            c_count = -1;
4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < num_valid; i++ )
4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int cur_label = *int_ptr[i];
4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( cur_label != prev_label )
4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    c_map[++c_count] = prev_label = cur_label;
4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                *int_ptr[i] = c_count;
4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // replace labels for missing values with -1
4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ; i < sample_count; i++ )
4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                *int_ptr[i] = -1;
4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( ci < 0 ) // process ordered variable
4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvPair32s32f* dst = get_ord_var_data( data_root, vi );
4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < sample_count; i++ )
4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                float val = ord_nan;
4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int si = sidx ? sidx[i] : i;
4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( !mask || !mask[si*m_step] )
4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( idata )
4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        val = (float)idata[si*step];
4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    else
4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        val = fdata[si*step];
4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( fabs(val) >= ord_nan )
4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        sprintf( err, "%d-th value of %d-th (ordered) "
4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            "variable (=%g) is too large", i, vi, val );
4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        CV_ERROR( CV_StsBadArg, err );
4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    num_valid++;
4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[i].i = i;
4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[i].val = val;
4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            icvSortPairs( dst, sample_count, 0 );
4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else // special case: process ordered response,
4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn             // it will be stored similarly to categorical vars (i.e. no pairs)
4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float* dst = get_ord_responses( data_root );
4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < sample_count; i++ )
4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                float val = ord_nan;
4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int si = sidx ? sidx[i] : i;
4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( idata )
4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    val = (float)idata[si*step];
4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                else
4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    val = fdata[si*step];
4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( fabs(val) >= ord_nan )
4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    sprintf( err, "%d-th value of %d-th (ordered) "
4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        "variable (=%g) is out of range", i, vi, val );
4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    CV_ERROR( CV_StsBadArg, err );
4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[i] = val;
4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cat_count->data.i[cat_var_count] = 0;
4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cat_ofs->data.i[cat_var_count] = total_c_count;
4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            num_valid = sample_count;
4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( vi < var_count )
4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            data_root->set_num_valid(vi, num_valid);
4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cv_n )
4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* dst = get_labels(data_root);
4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvRNG* r = &rng;
4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = vi = 0; i < sample_count; i++ )
4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dst[i] = vi++;
4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            vi &= vi < cv_n ? -1 : 0;
4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < sample_count; i++ )
4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int a = cvRandInt(r) % sample_count;
4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int b = cvRandInt(r) % sample_count;
4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_SWAP( dst[a], dst[b], vi );
4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cat_map->cols = MAX( total_c_count, 1 );
5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    have_priors = is_classifier && params.priors;
5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( is_classifier )
5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int m = get_num_classes();
5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum = 0;
5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < m; i++ )
5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double val = have_priors ? params.priors[i] : 1.;
5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( val <= 0 )
5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            priors->data.db[i] = val;
5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum += val;
5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // normalize weights
5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( have_priors )
5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvScale( priors, priors, 1./sum );
5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( priors_mult = cvCloneMat( priors ));
5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data )
5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        delete data;
5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &int_ptr );
5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &sample_idx );
5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &var_type0 );
5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &tmp_map );
5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* root = 0;
5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* isubsample_idx = 0;
5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* subsample_co = 0;
5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !data_root )
5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsError, "No training data has been set" );
5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _subsample_idx )
5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !isubsample_idx )
5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // make a copy of the root node
5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeNode temp;
5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i;
5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        root = new_node( 0, 1, 0, 0 );
5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        temp = *root;
5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        *root = *data_root;
5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        root->num_valid = temp.num_valid;
5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( root->num_valid )
5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < var_count; i++ )
5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                root->num_valid[i] = data_root->num_valid[i];
5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        root->cv_Tn = temp.cv_Tn;
5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        root->cv_node_risk = temp.cv_node_risk;
5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        root->cv_node_error = temp.cv_node_error;
5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* sidx = isubsample_idx->data.i;
5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* co, cur_ofs = 0;
5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int vi, i, total = data_root->sample_count;
5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int count = isubsample_idx->rows + isubsample_idx->cols - 1;
5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int work_var_count = get_work_var_count();
5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        root = new_node( 0, count, 1, 0 );
5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero( subsample_co );
5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        co = subsample_co->data.i;
5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; i++ )
5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            co[sidx[i]*2]++;
5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < total; i++ )
5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( co[i*2] )
5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                co[i*2+1] = cur_ofs;
6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cur_ofs += co[i*2];
6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                co[i*2+1] = -1;
6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( vi = 0; vi < work_var_count; vi++ )
6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int ci = get_var_type(vi);
6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( ci >= 0 || vi >= var_count )
6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const int* src = get_cat_var_data( data_root, vi );
6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int* dst = get_cat_var_data( root, vi );
6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int num_valid = 0;
6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < count; i++ )
6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    int val = src[sidx[i]];
6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    dst[i] = val;
6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    num_valid += val >= 0;
6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( vi < var_count )
6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    root->set_num_valid(vi, num_valid);
6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const CvPair32s32f* src = get_ord_var_data( data_root, vi );
6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CvPair32s32f* dst = get_ord_var_data( root, vi );
6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int j = 0, idx, count_i;
6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int num_valid = data_root->get_num_valid(vi);
6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < num_valid; i++ )
6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    idx = src[i].i;
6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    count_i = co[idx*2];
6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( count_i )
6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        float val = src[i].val;
6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        {
6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            dst[j].val = val;
6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            dst[j].i = cur_ofs;
6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        }
6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                root->set_num_valid(vi, j);
6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( ; i < total; i++ )
6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    idx = src[i].i;
6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    count_i = co[idx*2];
6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( count_i )
6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        float val = src[i].val;
6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        {
6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            dst[j].val = val;
6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            dst[j].i = cur_ofs;
6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        }
6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &isubsample_idx );
6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &subsample_co );
6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return root;
6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                    float* values, uchar* missing,
6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                    float* responses, bool get_class_idx )
6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* subsample_idx = 0;
6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* subsample_co = 0;
6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, vi, total = sample_count, count = total, cur_ofs = 0;
6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* sidx = 0;
6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* co = 0;
6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _subsample_idx )
6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
6956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sidx = subsample_idx->data.i;
6966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
6976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        co = subsample_co->data.i;
6986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero( subsample_co );
6996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        count = subsample_idx->cols + subsample_idx->rows - 1;
7006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < count; i++ )
7016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            co[sidx[i]*2]++;
7026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < total; i++ )
7036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int count_i = co[i*2];
7056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( count_i )
7066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
7076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                co[i*2+1] = cur_ofs*var_count;
7086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cur_ofs += count_i;
7096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
7106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( missing )
7146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        memset( missing, 1, count*var_count );
7156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( vi = 0; vi < var_count; vi++ )
7176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int ci = get_var_type(vi);
7196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( ci >= 0 ) // categorical
7206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float* dst = values + vi;
7226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            uchar* m = missing ? missing + vi : 0;
7236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const int* src = get_cat_var_data(data_root, vi);
7246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++, dst += var_count )
7266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
7276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int idx = sidx ? sidx[i] : i;
7286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int val = src[idx];
7296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                *dst = (float)val;
7306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( m )
7316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
7326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    *m = val < 0;
7336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    m += var_count;
7346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
7356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
7366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else // ordered
7386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float* dst = values + vi;
7406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            uchar* m = missing ? missing + vi : 0;
7416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const CvPair32s32f* src = get_ord_var_data(data_root, vi);
7426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int count1 = data_root->get_num_valid(vi);
7436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count1; i++ )
7456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
7466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int idx = src[i].i;
7476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int count_i = 1;
7486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( co )
7496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
7506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    count_i = co[idx*2];
7516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cur_ofs = co[idx*2+1];
7526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
7536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                else
7546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cur_ofs = idx*var_count;
7556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( count_i )
7566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
7576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    float val = src[i].val;
7586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( ; count_i > 0; count_i--, cur_ofs += var_count )
7596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
7606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dst[cur_ofs] = val;
7616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        if( m )
7626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            m[cur_ofs] = 0;
7636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
7646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
7656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
7666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // copy responses
7706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( responses )
7716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
7726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( is_classifier )
7736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const int* src = get_class_labels(data_root);
7756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
7766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
7776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int idx = sidx ? sidx[i] : i;
7786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int val = get_class_idx ? src[idx] :
7796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
7806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                responses[i] = (float)val;
7816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
7826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
7846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
7856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const float* src = get_ord_responses(data_root);
7866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
7876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
7886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int idx = sidx ? sidx[i] : i;
7896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                responses[i] = src[idx];
7906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
7916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
7956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &subsample_idx );
7976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &subsample_co );
7986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
7996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
8026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                         int storage_idx, int offset )
8036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
8056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->sample_count = count;
8076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->depth = parent ? parent->depth + 1 : 0;
8086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->parent = parent;
8096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->left = node->right = 0;
8106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->split = 0;
8116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->value = 0;
8126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->class_idx = 0;
8136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->maxlr = 0.;
8146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->buf_idx = storage_idx;
8166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->offset = offset;
8176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( nv_heap )
8186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->num_valid = (int*)cvSetNew( nv_heap );
8196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
8206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->num_valid = 0;
8216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
8226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->complexity = 0;
8236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.cv_folds > 0 && cv_heap )
8256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int cv_n = params.cv_folds;
8276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->Tn = INT_MAX;
8286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->cv_Tn = (int*)cvSetNew( cv_heap );
8296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
8306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->cv_node_error = node->cv_node_risk + cv_n;
8316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
8336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->Tn = 0;
8356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->cv_Tn = 0;
8366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->cv_node_risk = 0;
8376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->cv_node_error = 0;
8386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return node;
8416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
8456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int split_point, int inversed, float quality )
8466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
8486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->var_idx = vi;
8496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->ord.c = cmp_val;
8506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->ord.split_point = split_point;
8516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->inversed = inversed;
8526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->quality = quality;
8536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->next = 0;
8546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return split;
8566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
8606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
8626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, n = (max_c_count + 31)/32;
8636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->var_idx = vi;
8656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->inversed = 0;
8666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->quality = quality;
8676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n; i++ )
8686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split->subset[i] = 0;
8696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->next = 0;
8706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return split;
8726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::free_node( CvDTreeNode* node )
8766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split = node->split;
8786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    free_node_data( node );
8796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    while( split )
8806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeSplit* next = split->next;
8826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSetRemoveByPtr( split_heap, split );
8836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split = next;
8846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->split = 0;
8866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvSetRemoveByPtr( node_heap, node );
8876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::free_node_data( CvDTreeNode* node )
8916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
8926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( node->num_valid )
8936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSetRemoveByPtr( nv_heap, node->num_valid );
8956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->num_valid = 0;
8966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // do not free cv_* fields, as all the cross-validation related data is released at once.
8986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
8996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::free_train_data()
9026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &counts );
9046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &buf );
9056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &direction );
9066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &split_buf );
9076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMemStorage( &temp_storage );
9086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cv_heap = nv_heap = 0;
9096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::clear()
9136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    free_train_data();
9156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMemStorage( &tree_storage );
9176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &var_idx );
9196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &var_type );
9206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &cat_count );
9216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &cat_ofs );
9226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &cat_map );
9236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &priors );
9246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &priors_mult );
9256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node_heap = split_heap = 0;
9276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
9296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    have_labels = have_priors = is_classifier = false;
9306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_count = buf_size = 0;
9326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    shared = false;
9336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data_root = 0;
9356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rng = cvRNG(-1);
9376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTreeTrainData::get_num_classes() const
9416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return is_classifier ? cat_count->data.i[cat_var_count] : 0;
9436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTreeTrainData::get_var_type(int vi) const
9476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return var_type->data.i[vi];
9496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTreeTrainData::get_work_var_count() const
9536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return var_count + 1 + (have_labels ? 1 : 0);
9556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
9586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int oi = ~get_var_type(vi);
9606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert( 0 <= oi && oi < ord_var_count );
9616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
9626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                           n->offset + oi*n->sample_count*2);
9636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
9676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return get_cat_var_data( n, var_count );
9696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
9736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (float*)get_cat_var_data( n, var_count );
9756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint* CvDTreeTrainData::get_labels( CvDTreeNode* n )
9796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return have_labels ? get_cat_var_data( n, var_count + 1 ) : 0;
9816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
9856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int ci = get_var_type(vi);
9876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert( 0 <= ci && ci <= cat_var_count + 1 );
9886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return buf->data.i + n->buf_idx*buf->cols + n->offset +
9896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn           (ord_var_count*2 + ci)*n->sample_count;
9906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
9916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
9946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
9956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int idx = n->buf_idx + 1;
9966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( idx >= buf_count )
9976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        idx = shared ? 1 : 0;
9986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return idx;
9996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
10006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::write_params( CvFileStorage* fs )
10036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
10046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTreeTrainData::write_params" );
10056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
10076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int vi, vcount = var_count;
10096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
10116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "var_all", var_all );
10126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "var_count", var_count );
10136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "ord_var_count", ord_var_count );
10146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "cat_var_count", cat_var_count );
10156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
10176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
10186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( is_classifier )
10206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteInt( fs, "max_categories", params.max_categories );
10226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
10236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
10246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
10266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
10276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "max_depth", params.max_depth );
10296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
10306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
10316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.cv_folds > 1 )
10336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
10356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
10366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
10376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( priors )
10396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWrite( fs, "priors", priors );
10406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
10426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( var_idx )
10446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWrite( fs, "var_idx", var_idx );
10456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
10476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( vi = 0; vi < vcount; vi++ )
10496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
10506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
10526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cat_count && (cat_var_count > 0 || is_classifier) )
10546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ASSERT( cat_count != 0 );
10566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWrite( fs, "cat_count", cat_count );
10576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWrite( fs, "cat_map", cat_map );
10586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
10596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
10616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
10626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
10656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
10666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTreeTrainData::read_params" );
10676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
10696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvFileNode *tparams_node, *vartype_node;
10716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvSeqReader reader;
10726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int vi, max_split_size, tree_block_size;
10736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
10756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_all = cvReadIntByName( fs, node, "var_all" );
10766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_count = cvReadIntByName( fs, node, "var_count", var_all );
10776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
10786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
10796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
10816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( tparams_node ) // training parameters are not necessary
10836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
10856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( is_classifier )
10876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
10896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
10916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
10926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.regression_accuracy =
10936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
10946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
10976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
10986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
10996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( params.cv_folds > 1 )
11016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
11026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
11036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            params.truncate_pruned_tree =
11046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
11056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
11066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
11086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( priors )
11096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
11106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !CV_IS_MAT(priors) )
11116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
11126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            priors_mult = cvCloneMat( priors );
11136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
11146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
11156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
11176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( var_idx )
11186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
11196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(var_idx) ||
11206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            var_idx->cols != 1 && var_idx->rows != 1 ||
11216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            var_idx->cols + var_idx->rows - 1 != var_count ||
11226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
11236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsParseError,
11246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
11256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( vi = 0; vi < var_count; vi++ )
11276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
11286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
11296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
11306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ////// read var type
11326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
11336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cat_var_count = 0;
11356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ord_var_count = -1;
11366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
11376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
11396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
11406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
11416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
11426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
11436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            vartype_node->data.seq->total != var_count )
11446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
11456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvStartReadSeq( vartype_node->data.seq, &reader );
11476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( vi = 0; vi < var_count; vi++ )
11496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
11506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvFileNode* n = (CvFileNode*)reader.ptr;
11516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
11526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
11536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
11546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
11556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
11566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
11576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_type->data.i[var_count] = cat_var_count;
11586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ord_var_count = ~ord_var_count;
11606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
11616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
11626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //////
11636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( cat_var_count > 0 || is_classifier )
11656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
11666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int ccount, total_c_count = 0;
11676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
11686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
11696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
11716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cat_count->cols != 1 && cat_count->rows != 1 ||
11726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
11736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
11746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cat_map->cols != 1 && cat_map->rows != 1 ||
11756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
11766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsParseError,
11776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
11786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ccount = cat_var_count + is_classifier;
11806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
11826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cat_ofs->data.i[0] = 0;
11836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        max_c_count = 1;
11846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( vi = 0; vi < ccount; vi++ )
11866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
11876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int val = cat_count->data.i[vi];
11886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( val <= 0 )
11896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
11906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            max_c_count = MAX( max_c_count, val );
11916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cat_ofs->data.i[vi+1] = total_c_count += val;
11926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
11936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
11956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadSize,
11966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "cat_map vector length is not equal to the total number of categories in all categorical vars" );
11976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
11986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
12006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
12016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
12036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
12046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
12056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
12066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sizeof(CvDTreeNode), tree_storage ));
12076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
12086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            max_split_size, tree_storage ));
12096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
12116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/////////////////////// Decision Tree /////////////////////////
12156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTree::CvDTree()
12176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = 0;
12196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_importance = 0;
12206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default_model_name = "my_tree";
12216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
12236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::clear()
12276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &var_importance );
12296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data )
12306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
12316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !data->shared )
12326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            delete data;
12336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
12346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            free_tree();
12356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        data = 0;
12366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
12376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    root = 0;
12386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    pruned_tree_idx = -1;
12396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTree::~CvDTree()
12436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
12456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvDTreeNode* CvDTree::get_root() const
12496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return root;
12516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTree::get_pruned_tree_idx() const
12556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return pruned_tree_idx;
12576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeTrainData* CvDTree::get_data()
12616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return data;
12636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvDTree::train( const CvMat* _train_data, int _tflag,
12676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                     const CvMat* _responses, const CvMat* _var_idx,
12686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                     const CvMat* _sample_idx, const CvMat* _var_type,
12696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                     const CvMat* _missing_mask, CvDTreeParams _params )
12706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool result = false;
12726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTree::train" );
12746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
12766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
12786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = new CvDTreeTrainData( _train_data, _tflag, _responses,
12796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                 _var_idx, _sample_idx, _var_type,
12806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                 _missing_mask, _params, false );
12816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( result = do_train(0));
12826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
12846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
12866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
12876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
12906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
12916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool result = false;
12926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTree::train" );
12946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
12966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
12976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
12986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = _data;
12996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->shared = true;
13006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( result = do_train(_subsample_idx));
13016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
13036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
13056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
13066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvDTree::do_train( const CvMat* _subsample_idx )
13096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
13106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool result = false;
13116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTree::do_train" );
13136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
13156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    root = data->subsample_data( _subsample_idx );
13176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( try_split_node(root));
13196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->params.cv_folds > 0 )
13216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( prune_cv());
13226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !data->shared )
13246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        data->free_train_data();
13256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    result = true;
13276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
13296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
13316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
13326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::try_split_node( CvDTreeNode* node )
13356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
13366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* best_split = 0;
13376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, n = node->sample_count, vi;
13386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool can_split = true;
13396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double quality_scale;
13406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    calc_node_value( node );
13426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( node->sample_count <= data->params.min_sample_count ||
13446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->depth >= data->params.max_depth )
13456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        can_split = false;
13466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( can_split && data->is_classifier )
13486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // check if we have a "pure" node,
13506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // we assume that cls_count is filled by calc_node_value()
13516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* cls_count = data->counts->data.i;
13526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int nz = 0, m = data->get_num_classes();
13536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < m; i++ )
13546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            nz += cls_count[i] != 0;
13556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( nz == 1 ) // there is only one class
13566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            can_split = false;
13576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( can_split )
13596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
13616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            can_split = false;
13626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( can_split )
13656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        best_split = find_best_split(node);
13676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // TODO: check the split quality ...
13686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->split = best_split;
13696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !can_split || !best_split )
13726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        data->free_node_data(node);
13746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return;
13756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
13766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    quality_scale = calc_node_dir( node );
13786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->params.use_surrogates )
13806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
13816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // find all the surrogate splits
13826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // and sort them by their similarity to the primary one
13836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( vi = 0; vi < data->var_count; vi++ )
13846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
13856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvDTreeSplit* split;
13866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int ci = data->get_var_type(vi);
13876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( vi == best_split->var_idx )
13896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                continue;
13906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( ci >= 0 )
13926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_surrogate_split_cat( node, vi );
13936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
13946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_surrogate_split_ord( node, vi );
13956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
13966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( split )
13976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
13986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                // insert the split
13996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CvDTreeSplit* prev_split = node->split;
14006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split->quality = (float)(split->quality*quality_scale);
14016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                while( prev_split->next &&
14036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                       prev_split->next->quality > split->quality )
14046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    prev_split = prev_split->next;
14056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split->next = prev_split->next;
14066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                prev_split->next = split;
14076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
14086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
14096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
14106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split_node_data( node );
14126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    try_split_node( node->left );
14136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    try_split_node( node->right );
14146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
14156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// calculate direction (left(-1),right(1),missing(0))
14186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// for each sample using the best split
14196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the function returns scale coefficients for surrogate split quality factors.
14206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the scale is applied to normalize surrogate split quality relatively to the
14216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// best (primary) split quality. That is, if a surrogate split is absolutely
14226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// identical to the primary split, its quality will be set to the maximum value =
14236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// quality of the primary split; otherwise, it will be lower.
14246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// besides, the function compute node->maxlr,
14256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// minimum possible quality (w/o considering the above mentioned scale)
14266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// for a surrogate split. Surrogate splits with quality less than node->maxlr
14276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are not discarded.
14286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renndouble CvDTree::calc_node_dir( CvDTreeNode* node )
14296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
14306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    char* dir = (char*)data->direction->data.ptr;
14316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, n = node->sample_count, vi = node->split->var_idx;
14326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double L, R;
14336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert( !node->split->inversed );
14356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->get_var_type(vi) >= 0 ) // split on categorical var
14376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
14386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* labels = data->get_cat_var_data(node,vi);
14396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* subset = node->split->subset;
14406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !data->have_priors )
14426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
14436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int sum = 0, sum_abs = 0;
14446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < n; i++ )
14466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
14476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int idx = labels[i];
14486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
14496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum += d; sum_abs += d & 1;
14506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir[i] = (char)d;
14516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
14526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            R = (sum_abs + sum) >> 1;
14546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L = (sum_abs - sum) >> 1;
14556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
14566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
14576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
14586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const int* responses = data->get_class_labels(node);
14596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const double* priors = data->priors_mult->data.db;
14606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double sum = 0, sum_abs = 0;
14616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < n; i++ )
14636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
14646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int idx = labels[i];
14656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double w = priors[responses[i]];
14666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
14676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum += d*w; sum_abs += (d & 1)*w;
14686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir[i] = (char)d;
14696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
14706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            R = (sum_abs + sum) * 0.5;
14726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L = (sum_abs - sum) * 0.5;
14736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
14746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
14756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else // split on ordered var
14766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
14776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
14786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int split_point = node->split->ord.split_point;
14796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n1 = node->get_num_valid(vi);
14806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        assert( 0 <= split_point && split_point < n1-1 );
14826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !data->have_priors )
14846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
14856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i <= split_point; i++ )
14866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir[sorted[i].i] = (char)-1;
14876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ; i < n1; i++ )
14886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir[sorted[i].i] = (char)1;
14896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ; i < n; i++ )
14906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir[sorted[i].i] = (char)0;
14916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
14926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L = split_point-1;
14936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            R = n1 - split_point + 1;
14946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
14956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
14966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
14976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const int* responses = data->get_class_labels(node);
14986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const double* priors = data->priors_mult->data.db;
14996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L = R = 0;
15006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i <= split_point; i++ )
15026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
15036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int idx = sorted[i].i;
15046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double w = priors[responses[idx]];
15056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir[idx] = (char)-1;
15066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                L += w;
15076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
15086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ; i < n1; i++ )
15106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
15116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int idx = sorted[i].i;
15126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double w = priors[responses[idx]];
15136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir[idx] = (char)1;
15146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                R += w;
15156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
15166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ; i < n; i++ )
15186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir[sorted[i].i] = (char)0;
15196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
15206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->maxlr = MAX( L, R );
15236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return node->split->quality/(L + R);
15246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
15256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
15286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
15296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int vi;
15306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit *best_split = 0, *split = 0, *t;
15316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( vi = 0; vi < data->var_count; vi++ )
15336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int ci = data->get_var_type(vi);
15356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( node->get_num_valid(vi) <= 1 )
15366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            continue;
15376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( data->is_classifier )
15396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
15406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( ci >= 0 )
15416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_split_cat_class( node, vi );
15426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
15436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_split_ord_class( node, vi );
15446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
15456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
15466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
15476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( ci >= 0 )
15486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_split_cat_reg( node, vi );
15496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
15506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_split_ord_reg( node, vi );
15516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
15526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( split )
15546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
15556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !best_split || best_split->quality < split->quality )
15566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_SWAP( best_split, split, t );
15576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( split )
15586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvSetRemoveByPtr( data->split_heap, split );
15596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
15606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return best_split;
15636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
15646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
15676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
15686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float epsilon = FLT_EPSILON*2;
15696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
15706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* responses = data->get_class_labels(node);
15716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
15726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n1 = node->get_num_valid(vi);
15736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int m = data->get_num_classes();
15746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* rc0 = data->counts->data.i;
15756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
15766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
15776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, best_i = -1;
15786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double lsum2 = 0, rsum2 = 0, best_val = 0;
15796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
15806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // init arrays of class instance counters on both sides of the split
15826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < m; i++ )
15836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lc[i] = 0;
15856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rc[i] = rc0[i];
15866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
15876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compensate for missing values
15896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = n1; i < n; i++ )
15906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rc[responses[sorted[i].i]]--;
15916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !priors )
15936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
15946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int L = 0, R = n1;
15956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < m; i++ )
15976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rsum2 += (double)rc[i]*rc[i];
15986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
15996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n1 - 1; i++ )
16006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
16016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = responses[sorted[i].i];
16026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int lv, rv;
16036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L++; R--;
16046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lv = lc[idx]; rv = rc[idx];
16056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lsum2 += lv*2 + 1;
16066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rsum2 -= rv*2 - 1;
16076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lc[idx] = lv + 1; rc[idx] = rv - 1;
16086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( sorted[i].val + epsilon < sorted[i+1].val )
16106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
16116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = (lsum2*R + rsum2*L)/((double)L*R);
16126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( best_val < val )
16136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
16146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = val;
16156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_i = i;
16166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
16176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
16186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
16196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
16206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
16216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
16226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double L = 0, R = 0;
16236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < m; i++ )
16246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
16256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double wv = rc[i]*priors[i];
16266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            R += wv;
16276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rsum2 += wv*wv;
16286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
16296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n1 - 1; i++ )
16316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
16326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = responses[sorted[i].i];
16336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int lv, rv;
16346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double p = priors[idx], p2 = p*p;
16356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L += p; R -= p;
16366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lv = lc[idx]; rv = rc[idx];
16376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lsum2 += p2*(lv*2 + 1);
16386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rsum2 -= p2*(rv*2 - 1);
16396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lc[idx] = lv + 1; rc[idx] = rv - 1;
16406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( sorted[i].val + epsilon < sorted[i+1].val )
16426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
16436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = (lsum2*R + rsum2*L)/((double)L*R);
16446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( best_val < val )
16456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
16466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = val;
16476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_i = i;
16486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
16496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
16506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
16516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
16526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return best_i >= 0 ? data->new_split_ord( vi,
16546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
16556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        0, (float)best_val ) : 0;
16566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
16576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::cluster_categories( const int* vectors, int n, int m,
16606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                int* csums, int k, int* labels )
16616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
16626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
16636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int iters = 0, max_iters = 100;
16646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, idx;
16656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
16666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double *v_weights = buf, *c_weights = buf + k;
16676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool modified = true;
16686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvRNG* r = &data->rng;
16696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // assign labels randomly
16716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = idx = 0; i < n; i++ )
16726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
16736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int sum = 0;
16746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* v = vectors + i*m;
16756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        labels[i] = idx++;
16766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        idx &= idx < k ? -1 : 0;
16776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // compute weight of each vector
16796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < m; j++ )
16806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum += v[j];
16816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        v_weights[i] = sum ? 1./sum : 0.;
16826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
16836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n; i++ )
16856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
16866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i1 = cvRandInt(r) % n;
16876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i2 = cvRandInt(r) % n;
16886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_SWAP( labels[i1], labels[i2], j );
16896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
16906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
16916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( iters = 0; iters <= max_iters; iters++ )
16926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
16936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // calculate csums
16946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < k; i++ )
16956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
16966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < m; j++ )
16976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                csums[i*m + j] = 0;
16986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
16996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
17016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
17026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const int* v = vectors + i*m;
17036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int* s = csums + labels[i]*m;
17046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < m; j++ )
17056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                s[j] += v[j];
17066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
17076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // exit the loop here, when we have up-to-date csums
17096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( iters == max_iters || !modified )
17106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
17116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        modified = false;
17136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // calculate weight of each cluster
17156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < k; i++ )
17166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
17176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const int* s = csums + i*m;
17186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int sum = 0;
17196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < m; j++ )
17206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum += s[j];
17216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            c_weights[i] = sum ? 1./sum : 0;
17226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
17236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // now for each vector determine the closest cluster
17256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
17266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
17276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const int* v = vectors + i*m;
17286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double alpha = v_weights[i];
17296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double min_dist2 = DBL_MAX;
17306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int min_idx = -1;
17316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( idx = 0; idx < k; idx++ )
17336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
17346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const int* s = csums + idx*m;
17356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double dist2 = 0., beta = c_weights[idx];
17366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < m; j++ )
17376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
17386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double t = v[j]*alpha - s[j]*beta;
17396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    dist2 += t*t;
17406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
17416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( min_dist2 > dist2 )
17426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
17436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    min_dist2 = dist2;
17446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    min_idx = idx;
17456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
17466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
17476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( min_idx != labels[i] )
17496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                modified = true;
17506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            labels[i] = min_idx;
17516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
17526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
17536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
17546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
17576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
17586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split;
17596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* labels = data->get_cat_var_data(node, vi);
17606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* responses = data->get_class_labels(node);
17616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int ci = data->get_var_type(vi);
17626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
17636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int m = data->get_num_classes();
17646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int _mi = data->cat_count->data.i[ci], mi = _mi;
17656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
17666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
17676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;
17686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
17696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* cluster_labels = 0;
17706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int** int_ptr = 0;
17716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, k, idx;
17726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double L = 0, R = 0;
17736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val = 0;
17746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
17756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double* priors = data->priors_mult->data.db;
17766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // init array of counters:
17786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
17796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( j = -1; j < mi; j++ )
17806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < m; k++ )
17816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cjk[j*m + k] = 0;
17826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n; i++ )
17846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
17856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        j = labels[i];
17866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        k = responses[i];
17876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cjk[j*m + k]++;
17886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
17896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
17906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( m > 2 )
17916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
17926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( mi > data->params.max_categories )
17936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
17946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            mi = MIN(data->params.max_categories, n);
17956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cjk += _mi*m;
17966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));
17976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
17986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
17996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        subset_i = 1;
18006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        subset_n = 1 << mi;
18016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
18026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
18036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
18046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        assert( m == 2 );
18056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
18066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < mi; j++ )
18076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int_ptr[j] = cjk + j*2 + 1;
18086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        icvSortIntPtr( int_ptr, mi, 0 );
18096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        subset_i = 0;
18106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        subset_n = mi;
18116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
18126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
18136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k < m; k++ )
18146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
18156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int sum = 0;
18166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < mi; j++ )
18176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum += cjk[j*m + k];
18186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rc[k] = sum;
18196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lc[k] = 0;
18206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
18216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
18226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( j = 0; j < mi; j++ )
18236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
18246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum = 0;
18256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < m; k++ )
18266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum += cjk[j*m + k]*priors[k];
18276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        c_weights[j] = sum;
18286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        R += c_weights[j];
18296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
18306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
18316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( ; subset_i < subset_n; subset_i++ )
18326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
18336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double weight;
18346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* crow;
18356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double lsum2 = 0, rsum2 = 0;
18366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
18376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( m == 2 )
18386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            idx = (int)(int_ptr[subset_i] - cjk)/2;
18396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
18406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
18416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int graycode = (subset_i>>1)^subset_i;
18426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int diff = graycode ^ prevcode;
18436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
18446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // determine index of the changed bit.
18456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            Cv32suf u;
18466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            idx = diff >= (1 << 16) ? 16 : 0;
18476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            u.f = (float)(((diff >> 16) | diff) & 65535);
18486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            idx += (u.i >> 23) - 127;
18496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            subtract = graycode < prevcode;
18506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            prevcode = graycode;
18516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
18526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
18536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        crow = cjk + idx*m;
18546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        weight = c_weights[idx];
18556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( weight < FLT_EPSILON )
18566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            continue;
18576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
18586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !subtract )
18596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
18606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k < m; k++ )
18616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
18626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int t = crow[k];
18636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int lval = lc[k] + t;
18646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int rval = rc[k] - t;
18656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double p = priors[k], p2 = p*p;
18666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                lsum2 += p2*lval*lval;
18676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                rsum2 += p2*rval*rval;
18686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                lc[k] = lval; rc[k] = rval;
18696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
18706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L += weight;
18716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            R -= weight;
18726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
18736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
18746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
18756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k < m; k++ )
18766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
18776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int t = crow[k];
18786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int lval = lc[k] - t;
18796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int rval = rc[k] + t;
18806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double p = priors[k], p2 = p*p;
18816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                lsum2 += p2*lval*lval;
18826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                rsum2 += p2*rval*rval;
18836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                lc[k] = lval; rc[k] = rval;
18846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
18856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            L -= weight;
18866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            R += weight;
18876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
18886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
18896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( L > FLT_EPSILON && R > FLT_EPSILON )
18906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
18916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double val = (lsum2*R + rsum2*L)/((double)L*R);
18926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( best_val < val )
18936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
18946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_val = val;
18956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_subset = subset_i;
18966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
18976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
18986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
18996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( best_subset < 0 )
19016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return 0;
19026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split = data->new_split_cat( vi, (float)best_val );
19046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( m == 2 )
19066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
19076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i <= best_subset; i++ )
19086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
19096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            idx = (int)(int_ptr[i] - cjk) >> 1;
19106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            split->subset[idx >> 5] |= 1 << (idx & 31);
19116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
19126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
19136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
19146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
19156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < _mi; i++ )
19166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
19176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            idx = cluster_labels ? cluster_labels[i] : i;
19186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( best_subset & (1 << idx) )
19196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split->subset[i >> 5] |= 1 << (i & 31);
19206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
19216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
19226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return split;
19246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
19256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
19286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
19296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float epsilon = FLT_EPSILON*2;
19306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
19316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float* responses = data->get_ord_responses(node);
19326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
19336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n1 = node->get_num_valid(vi);
19346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, best_i = -1;
19356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val = 0, lsum = 0, rsum = node->value*n;
19366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int L = 0, R = n1;
19376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // compensate for missing values
19396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = n1; i < n; i++ )
19406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rsum -= responses[sorted[i].i];
19416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // find the optimal split
19436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n1 - 1; i++ )
19446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
19456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float t = responses[sorted[i].i];
19466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        L++; R--;
19476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lsum += t;
19486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rsum -= t;
19496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( sorted[i].val + epsilon < sorted[i+1].val )
19516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
19526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
19536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( best_val < val )
19546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
19556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_val = val;
19566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_i = i;
19576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
19586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
19596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
19606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return best_i >= 0 ? data->new_split_ord( vi,
19626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
19636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        0, (float)best_val ) : 0;
19646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
19656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
19686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
19696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split;
19706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* labels = data->get_cat_var_data(node, vi);
19716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float* responses = data->get_ord_responses(node);
19726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int ci = data->get_var_type(vi);
19736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
19746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int mi = data->cat_count->data.i[ci];
19756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
19766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
19776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double** sum_ptr = 0;
19786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, L = 0, R = 0;
19796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val = 0, lsum = 0, rsum = 0;
19806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int best_subset = -1, subset_i;
19816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = -1; i < mi; i++ )
19836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum[i] = counts[i] = 0;
19846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // calculate sum response and weight of each category of the input var
19866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n; i++ )
19876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
19886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = labels[i];
19896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double s = sum[idx] + responses[i];
19906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int nc = counts[idx] + 1;
19916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum[idx] = s;
19926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        counts[idx] = nc;
19936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
19946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
19956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // calculate average response in each category
19966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < mi; i++ )
19976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
19986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        R += counts[i];
19996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rsum += sum[i];
20006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum[i] /= MAX(counts[i],1);
20016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum_ptr[i] = sum + i;
20026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
20036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    icvSortDblPtr( sum_ptr, mi, 0 );
20056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // revert back to unnormalized sums
20076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // (there should be a very little loss of accuracy)
20086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < mi; i++ )
20096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sum[i] *= counts[i];
20106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( subset_i = 0; subset_i < mi-1; subset_i++ )
20126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
20136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = (int)(sum_ptr[subset_i] - sum);
20146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int ni = counts[idx];
20156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( ni )
20176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
20186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double s = sum[idx];
20196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lsum += s; L += ni;
20206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rsum -= s; R -= ni;
20216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( L && R )
20236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
20246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
20256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( best_val < val )
20266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
20276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = val;
20286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_subset = subset_i;
20296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
20306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
20316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
20326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
20336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( best_subset < 0 )
20356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return 0;
20366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split = data->new_split_cat( vi, (float)best_val );
20386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i <= best_subset; i++ )
20396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
20406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int idx = (int)(sum_ptr[i] - sum);
20416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split->subset[idx >> 5] |= 1 << (idx & 31);
20426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
20436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return split;
20456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
20466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
20496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
20506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float epsilon = FLT_EPSILON*2;
20516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
20526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const char* dir = (char*)data->direction->data.ptr;
20536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n1 = node->get_num_valid(vi);
20546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // LL - number of samples that both the primary and the surrogate splits send to the left
20556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // LR - ... primary split sends to the left and the surrogate split sends to the right
20566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // RL - ... primary split sends to the right and the surrogate split sends to the left
20576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // RR - ... both send to the right
20586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, best_i = -1, best_inversed = 0;
20596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val;
20606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !data->have_priors )
20626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
20636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int LL = 0, RL = 0, LR, RR;
20646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
20656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int sum = 0, sum_abs = 0;
20666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n1; i++ )
20686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
20696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = dir[sorted[i].i];
20706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum += d; sum_abs += d & 1;
20716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
20726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // sum_abs = R + L; sum = R - L
20746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        RR = (sum_abs + sum) >> 1;
20756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        LR = (sum_abs - sum) >> 1;
20766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // initially all the samples are sent to the right by the surrogate split,
20786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // LR of them are sent to the left by primary split, and RR - to the right.
20796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
20806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n1 - 1; i++ )
20816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
20826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = dir[sorted[i].i];
20836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
20846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( d < 0 )
20856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
20866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                LL++; LR--;
20876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
20886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
20896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = LL + RR;
20906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_i = i; best_inversed = 0;
20916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
20926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
20936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else if( d > 0 )
20946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
20956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                RL++; RR--;
20966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
20976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
20986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = RL + LR;
20996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_i = i; best_inversed = 1;
21006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
21016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
21026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
21036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        best_val = _best_val;
21046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
21056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
21066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
21076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double LL = 0, RL = 0, LR, RR;
21086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double worst_val = node->maxlr;
21096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum = 0, sum_abs = 0;
21106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const double* priors = data->priors_mult->data.db;
21116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* responses = data->get_class_labels(node);
21126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        best_val = worst_val;
21136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n1; i++ )
21156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
21166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = sorted[i].i;
21176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = priors[responses[idx]];
21186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = dir[idx];
21196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum += d*w; sum_abs += (d & 1)*w;
21206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
21216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // sum_abs = R + L; sum = R - L
21236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        RR = (sum_abs + sum)*0.5;
21246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        LR = (sum_abs - sum)*0.5;
21256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // initially all the samples are sent to the right by the surrogate split,
21276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // LR of them are sent to the left by primary split, and RR - to the right.
21286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
21296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n1 - 1; i++ )
21306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
21316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = sorted[i].i;
21326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = priors[responses[idx]];
21336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = dir[idx];
21346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( d < 0 )
21366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
21376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                LL += w; LR -= w;
21386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
21396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
21406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = LL + RR;
21416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_i = i; best_inversed = 0;
21426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
21436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
21446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else if( d > 0 )
21456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
21466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                RL += w; RR -= w;
21476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
21486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
21496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_val = RL + LR;
21506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    best_i = i; best_inversed = 1;
21516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
21526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
21536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
21546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
21556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
21576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
21586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        best_inversed, (float)best_val ) : 0;
21596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
21606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
21636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
21646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* labels = data->get_cat_var_data(node, vi);
21656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const char* dir = (char*)data->direction->data.ptr;
21666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int n = node->sample_count;
21676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // LL - number of samples that both the primary and the surrogate splits send to the left
21686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // LR - ... primary split sends to the left and the surrogate split sends to the right
21696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // RL - ... primary split sends to the right and the surrogate split sends to the left
21706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // RR - ... both send to the right
21716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
21726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
21736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double best_val = 0;
21746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
21756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* rc = lc + mi + 1;
21766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = -1; i < mi; i++ )
21786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        lc[i] = rc[i] = 0;
21796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // for each category calculate the weight of samples
21816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // sent to the left (lc) and to the right (rc) by the primary split
21826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !data->have_priors )
21836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
21846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;
21856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* _rc = _lc + mi + 1;
21866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = -1; i < mi; i++ )
21886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _lc[i] = _rc[i] = 0;
21896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
21916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
21926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = labels[i];
21936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = dir[i];
21946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int sum = _lc[idx] + d;
21956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int sum_abs = _rc[idx] + (d & 1);
21966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _lc[idx] = sum; _rc[idx] = sum_abs;
21976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
21986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
21996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < mi; i++ )
22006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
22016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int sum = _lc[i];
22026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int sum_abs = _rc[i];
22036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lc[i] = (sum_abs - sum) >> 1;
22046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rc[i] = (sum_abs + sum) >> 1;
22056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
22066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
22076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
22086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
22096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const double* priors = data->priors_mult->data.db;
22106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* responses = data->get_class_labels(node);
22116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
22136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
22146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = labels[i];
22156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w = priors[responses[i]];
22166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = dir[i];
22176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double sum = lc[idx] + d*w;
22186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double sum_abs = rc[idx] + (d & 1)*w;
22196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lc[idx] = sum; rc[idx] = sum_abs;
22206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
22216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < mi; i++ )
22236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
22246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double sum = lc[i];
22256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double sum_abs = rc[i];
22266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            lc[i] = (sum_abs - sum) * 0.5;
22276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rc[i] = (sum_abs + sum) * 0.5;
22286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
22296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
22306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // 2. now form the split.
22326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // in each category send all the samples to the same direction as majority
22336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < mi; i++ )
22346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
22356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double lval = lc[i], rval = rc[i];
22366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( lval > rval )
22376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
22386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            split->subset[i >> 5] |= 1 << (i & 31);
22396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            best_val += lval;
22406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            l_win++;
22416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
22426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
22436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            best_val += rval;
22446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
22456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->quality = (float)best_val;
22476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
22486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSetRemoveByPtr( data->split_heap, split ), split = 0;
22496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return split;
22516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
22526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::calc_node_value( CvDTreeNode* node )
22556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
22566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
22576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* cv_labels = data->get_labels(node);
22586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->is_classifier )
22606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
22616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // in case of classification tree:
22626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * node value is the label of the class that has the largest weight in the node.
22636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * node risk is the weighted number of misclassified samples,
22646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * j-th cross-validation fold value and risk are calculated as above,
22656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    but using the samples with cv_labels(*)!=j.
22666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * j-th cross-validation fold error is calculated as the weighted number of
22676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    misclassified samples with cv_labels(*)==j.
22686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // compute the number of instances of each class
22706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* cls_count = data->counts->data.i;
22716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const int* responses = data->get_class_labels(node);
22726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int m = data->get_num_classes();
22736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));
22746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double max_val = -1, total_weight = 0;
22756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int max_k = -1;
22766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double* priors = data->priors_mult->data.db;
22776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < m; k++ )
22796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cls_count[k] = 0;
22806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( cv_n == 0 )
22826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
22836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < n; i++ )
22846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cls_count[responses[i]]++;
22856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
22866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
22876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
22886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cv_n; j++ )
22896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( k = 0; k < m; k++ )
22906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cv_cls_count[j*m + k] = 0;
22916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < n; i++ )
22936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
22946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                j = cv_labels[i]; k = responses[i];
22956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cv_cls_count[j*m + k]++;
22966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
22976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
22986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cv_n; j++ )
22996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( k = 0; k < m; k++ )
23006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cls_count[k] += cv_cls_count[j*m + k];
23016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
23026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
23036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( data->have_priors && node->parent == 0 )
23046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
23056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // compute priors_mult from priors, take the sample ratio into account.
23066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double sum = 0;
23076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k < m; k++ )
23086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
23096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int n_k = cls_count[k];
23106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
23116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum += priors[k];
23126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
23136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sum = 1./sum;
23146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k < m; k++ )
23156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                priors[k] *= sum;
23166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
23176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
23186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < m; k++ )
23196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
23206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double val = cls_count[k]*priors[k];
23216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            total_weight += val;
23226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( max_val < val )
23236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
23246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                max_val = val;
23256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                max_k = k;
23266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
23276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
23286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
23296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->class_idx = max_k;
23306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->value = data->cat_map->data.i[
23316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            data->cat_ofs->data.i[data->cat_var_count] + max_k];
23326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->node_risk = total_weight - max_val;
23336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
23346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < cv_n; j++ )
23356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
23366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double sum_k = 0, sum = 0, max_val_k = 0;
23376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            max_val = -1; max_k = -1;
23386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
23396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k < m; k++ )
23406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
23416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double w = priors[k];
23426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val_k = cv_cls_count[j*m + k]*w;
23436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = cls_count[k]*w - val_k;
23446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum_k += val_k;
23456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum += val;
23466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( max_val < val )
23476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
23486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    max_val = val;
23496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    max_val_k = val_k;
23506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    max_k = k;
23516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
23526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
23536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
23546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->cv_Tn[j] = INT_MAX;
23556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->cv_node_risk[j] = sum - max_val;
23566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->cv_node_error[j] = sum_k - max_val_k;
23576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
23586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
23596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
23606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
23616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // in case of regression tree:
23626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
23636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    n is the number of samples in the node.
23646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
23656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * j-th cross-validation fold value and risk are calculated as above,
23666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    but using the samples with cv_labels(*)!=j.
23676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //  * j-th cross-validation fold error is calculated
23686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    using samples with cv_labels(*)==j as the test subset:
23696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
23706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    where node_value_j is the node value calculated
23716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    as described in the previous bullet, and summation is done
23726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    over the samples with cv_labels(*)==j.
23736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
23746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double sum = 0, sum2 = 0;
23756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const float* values = data->get_ord_responses(node);
23766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double *cv_sum = 0, *cv_sum2 = 0;
23776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* cv_count = 0;
23786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
23796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( cv_n == 0 )
23806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
23816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < n; i++ )
23826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
23836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double t = values[i];
23846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum += t;
23856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum2 += t*t;
23866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
23876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
23886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
23896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
23906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
23916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
23926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
23936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
23946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cv_n; j++ )
23956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
23966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cv_sum[j] = cv_sum2[j] = 0.;
23976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cv_count[j] = 0;
23986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
23996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < n; i++ )
24016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
24026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                j = cv_labels[i];
24036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double t = values[i];
24046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double s = cv_sum[j] + t;
24056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double s2 = cv_sum2[j] + t*t;
24066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int nc = cv_count[j] + 1;
24076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cv_sum[j] = s;
24086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cv_sum2[j] = s2;
24096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cv_count[j] = nc;
24106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
24116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cv_n; j++ )
24136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
24146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum += cv_sum[j];
24156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum2 += cv_sum2[j];
24166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
24176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
24186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->node_risk = sum2 - (sum/n)*sum;
24206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->value = sum/n;
24216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < cv_n; j++ )
24236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
24246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double s = cv_sum[j], si = sum - s;
24256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double s2 = cv_sum2[j], s2i = sum2 - s2;
24266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int c = cv_count[j], ci = n - c;
24276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double r = si/MAX(ci,1);
24286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->cv_node_risk[j] = s2i - r*r*ci;
24296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
24306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->cv_Tn[j] = INT_MAX;
24316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
24326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
24336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
24346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::complete_node_dir( CvDTreeNode* node )
24376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
24386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
24396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int nz = n - node->get_num_valid(node->split->var_idx);
24406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    char* dir = (char*)data->direction->data.ptr;
24416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // try to complete direction using surrogate splits
24436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( nz && data->params.use_surrogates )
24446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
24456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeSplit* split = node->split->next;
24466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ; split != 0 && nz; split = split->next )
24476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
24486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int inversed_mask = split->inversed ? -1 : 0;
24496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            vi = split->var_idx;
24506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( data->get_var_type(vi) >= 0 ) // split on categorical var
24526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
24536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const int* labels = data->get_cat_var_data(node, vi);
24546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const int* subset = split->subset;
24556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < n; i++ )
24576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
24586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    int idx;
24596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( !dir[i] && (idx = labels[i]) >= 0 )
24606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
24616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        int d = CV_DTREE_CAT_DIR(idx,subset);
24626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
24636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        if( --nz )
24646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            break;
24656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
24666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
24676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
24686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else // split on ordered var
24696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
24706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
24716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int split_point = split->ord.split_point;
24726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int n1 = node->get_num_valid(vi);
24736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                assert( 0 <= split_point && split_point < n-1 );
24756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < n1; i++ )
24776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
24786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    int idx = sorted[i].i;
24796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( !dir[idx] )
24806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
24816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        int d = i <= split_point ? -1 : 1;
24826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
24836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        if( --nz )
24846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            break;
24856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
24866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
24876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
24886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
24896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
24906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
24916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // find the default direction for the rest
24926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( nz )
24936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
24946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = nr = 0; i < n; i++ )
24956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            nr += dir[i] > 0;
24966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        nl = n - nr - nz;
24976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        d0 = nl > nr ? -1 : nr > nl;
24986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
24996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // make sure that every sample is directed either to the left or to the right
25016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < n; i++ )
25026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
25036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int d = dir[i];
25046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !d )
25056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
25066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            d = d0;
25076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !d )
25086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                d = d1, d1 = -d1;
25096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
25106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        d = d > 0;
25116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        dir[i] = (char)d; // remap (-1,1) to (0,1)
25126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
25136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
25146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::split_node_data( CvDTreeNode* node )
25176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
25186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int vi, i, n = node->sample_count, nl, nr;
25196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    char* dir = (char*)data->direction->data.ptr;
25206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode *left = 0, *right = 0;
25216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* new_idx = data->split_buf->data.i;
25226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int new_buf_idx = data->get_child_buf_idx( node );
25236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int work_var_count = data->get_work_var_count();
25246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // speedup things a little, especially for tree ensembles with a lots of small trees:
25266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //   do not physically split the input data between the left and right child nodes
25276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //   when we are not going to split them further,
25286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //   as calc_node_value() does not requires input features anyway.
25296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool split_input_data;
25306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    complete_node_dir(node);
25326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = nl = nr = 0; i < n; i++ )
25346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
25356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int d = dir[i];
25366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // initialize new indices for splitting ordered variables
25376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
25386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        nr += d;
25396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        nl += d^1;
25406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
25416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
25436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
25446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (data->ord_var_count + work_var_count)*nl );
25456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split_input_data = node->depth + 1 < data->params.max_depth &&
25476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (node->left->sample_count > data->params.min_sample_count ||
25486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->right->sample_count > data->params.min_sample_count);
25496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // split ordered variables, keep both halves sorted.
25516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( vi = 0; vi < data->var_count; vi++ )
25526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
25536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int ci = data->get_var_type(vi);
25546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n1 = node->get_num_valid(vi);
25556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
25566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvPair32s32f tl, tr;
25576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( ci >= 0 || !split_input_data )
25596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            continue;
25606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        src = data->get_ord_var_data(node, vi);
25626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ldst0 = ldst = data->get_ord_var_data(left, vi);
25636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rdst0 = rdst = data->get_ord_var_data(right, vi);
25646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        tl = ldst0[nl]; tr = rdst0[nr];
25656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // split sorted
25676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n1; i++ )
25686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
25696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = src[i].i;
25706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float val = src[i].val;
25716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = dir[idx];
25726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            idx = new_idx[idx];
25736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ldst->i = rdst->i = idx;
25746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ldst->val = rdst->val = val;
25756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ldst += d^1;
25766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rdst += d;
25776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
25786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        left->set_num_valid(vi, (int)(ldst - ldst0));
25806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        right->set_num_valid(vi, (int)(rdst - rdst0));
25816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // split missing
25836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ; i < n; i++ )
25846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
25856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = src[i].i;
25866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = dir[idx];
25876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            idx = new_idx[idx];
25886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ldst->i = rdst->i = idx;
25896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ldst->val = rdst->val = ord_nan;
25906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ldst += d^1;
25916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rdst += d;
25926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
25936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ldst0[nl] = tl; rdst0[nr] = tr;
25956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
25966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
25976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // split categorical vars, responses and cv_labels using new_idx relocation table
25986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( vi = 0; vi < work_var_count; vi++ )
25996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
26006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int ci = data->get_var_type(vi);
26016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n1 = node->get_num_valid(vi), nr1 = 0;
26026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int *src, *ldst0, *rdst0, *ldst, *rdst;
26036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int tl, tr;
26046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( ci < 0 || (vi < data->var_count && !split_input_data) )
26066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            continue;
26076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        src = data->get_cat_var_data(node, vi);
26096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ldst0 = ldst = data->get_cat_var_data(left, vi);
26106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        rdst0 = rdst = data->get_cat_var_data(right, vi);
26116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        tl = ldst0[nl]; tr = rdst0[nr];
26126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
26146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
26156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int d = dir[i];
26166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int val = src[i];
26176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            *ldst = *rdst = val;
26186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ldst += d^1;
26196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            rdst += d;
26206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            nr1 += (val >= 0)&d;
26216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
26226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( vi < data->var_count )
26246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
26256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            left->set_num_valid(vi, n1 - nr1);
26266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            right->set_num_valid(vi, nr1);
26276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
26286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ldst0[nl] = tl; rdst0[nr] = tr;
26306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
26316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // deallocate the parent node data that is not needed anymore
26336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->free_node_data(node);
26346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
26356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::prune_cv()
26386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
26396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* ab = 0;
26406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* temp = 0;
26416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* err_jk = 0;
26426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
26446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // 2. choose the best tree index (if need, apply 1SE rule).
26456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // 3. store the best index and cut the branches.
26466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTree::prune_cv" );
26486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
26506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
26526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // currently, 1SE for regression is not implemented
26536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
26546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* err;
26556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double min_err = 0, min_err_se = 0;
26566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int min_idx = -1;
26576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
26596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // build the main tree sequence, calculate alpha's
26616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for(;;tree_count++)
26626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
26636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double min_alpha = update_tree_rnc(tree_count, -1);
26646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( cut_tree(tree_count, -1, min_alpha) )
26656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
26666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( ab->cols <= tree_count )
26686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
26696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
26706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ti = 0; ti < ab->cols; ti++ )
26716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                temp->data.db[ti] = ab->data.db[ti];
26726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvReleaseMat( &ab );
26736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ab = temp;
26746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            temp = 0;
26756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
26766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ab->data.db[tree_count] = min_alpha;
26786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
26796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ab->data.db[0] = 0.;
26816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( tree_count > 0 )
26836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
26846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ti = 1; ti < tree_count-1; ti++ )
26856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
26866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ab->data.db[tree_count-1] = DBL_MAX*0.5;
26876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
26896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        err = err_jk->data.db;
26906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
26916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < cv_n; j++ )
26926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
26936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int tj = 0, tk = 0;
26946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ; tk < tree_count; tj++ )
26956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
26966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double min_alpha = update_tree_rnc(tj, j);
26976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( cut_tree(tj, j, min_alpha) )
26986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    min_alpha = DBL_MAX;
26996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( ; tk < tree_count; tk++ )
27016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
27026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( ab->data.db[tk] > min_alpha )
27036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        break;
27046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    err[j*tree_count + tk] = root->tree_error;
27056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
27066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
27076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
27086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ti = 0; ti < tree_count; ti++ )
27106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
27116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double sum_err = 0;
27126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < cv_n; j++ )
27136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum_err += err[j*tree_count + ti];
27146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( ti == 0 || sum_err < min_err )
27156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
27166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                min_err = sum_err;
27176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                min_idx = ti;
27186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( use_1se )
27196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    min_err_se = sqrt( sum_err*(n - sum_err) );
27206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
27216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else if( sum_err < min_err + min_err_se )
27226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                min_idx = ti;
27236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
27246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
27256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    pruned_tree_idx = min_idx;
27276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    free_prune_data(data->params.truncate_pruned_tree != 0);
27286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
27306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &err_jk );
27326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &ab );
27336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &temp );
27346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
27356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renndouble CvDTree::update_tree_rnc( int T, int fold )
27386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
27396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* node = root;
27406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double min_alpha = DBL_MAX;
27416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for(;;)
27436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
27446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeNode* parent;
27456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for(;;)
27466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
27476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
27486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( t <= T || !node->left )
27496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
27506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                node->complexity = 1;
27516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                node->tree_risk = node->node_risk;
27526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                node->tree_error = 0.;
27536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( fold >= 0 )
27546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
27556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    node->tree_risk = node->cv_node_risk[fold];
27566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    node->tree_error = node->cv_node_error[fold];
27576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
27586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                break;
27596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
27606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = node->left;
27616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
27626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( parent = node->parent; parent && parent->right == node;
27646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = parent, parent = parent->parent )
27656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
27666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            parent->complexity += node->complexity;
27676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            parent->tree_risk += node->tree_risk;
27686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            parent->tree_error += node->tree_error;
27696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
27716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                - parent->tree_risk)/(parent->complexity - 1);
27726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            min_alpha = MIN( min_alpha, parent->alpha );
27736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
27746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !parent )
27766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
27776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        parent->complexity = node->complexity;
27796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        parent->tree_risk = node->tree_risk;
27806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        parent->tree_error = node->tree_error;
27816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node = parent->right;
27826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
27836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return min_alpha;
27856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
27866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvDTree::cut_tree( int T, int fold, double min_alpha )
27896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
27906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* node = root;
27916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !node->left )
27926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        return 1;
27936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
27946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for(;;)
27956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
27966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeNode* parent;
27976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for(;;)
27986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
27996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
28006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( t <= T || !node->left )
28016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                break;
28026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( node->alpha <= min_alpha + FLT_EPSILON )
28036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
28046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( fold >= 0 )
28056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    node->cv_Tn[fold] = T;
28066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                else
28076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    node->Tn = T;
28086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( node == root )
28096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    return 1;
28106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                break;
28116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
28126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = node->left;
28136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
28146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( parent = node->parent; parent && parent->right == node;
28166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = parent, parent = parent->parent )
28176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ;
28186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !parent )
28206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
28216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node = parent->right;
28236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
28246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return 0;
28266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
28276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::free_prune_data(bool cut_tree)
28306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
28316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* node = root;
28326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for(;;)
28346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
28356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeNode* parent;
28366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for(;;)
28376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
28386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
28396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // as we will clear the whole cross-validation heap at the end
28406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->cv_Tn = 0;
28416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node->cv_node_error = node->cv_node_risk = 0;
28426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !node->left )
28436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                break;
28446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = node->left;
28456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
28466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( parent = node->parent; parent && parent->right == node;
28486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = parent, parent = parent->parent )
28496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
28506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( cut_tree && parent->Tn <= pruned_tree_idx )
28516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
28526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                data->free_node( parent->left );
28536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                data->free_node( parent->right );
28546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                parent->left = parent->right = 0;
28556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
28566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
28576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !parent )
28596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
28606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node = parent->right;
28626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
28636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->cv_heap )
28656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvClearSet( data->cv_heap );
28666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
28676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::free_tree()
28706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
28716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( root && data && data->shared )
28726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
28736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        pruned_tree_idx = INT_MIN;
28746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        free_prune_data(true);
28756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        data->free_node(root);
28766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        root = 0;
28776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
28786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
28796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeNode* CvDTree::predict( const CvMat* _sample,
28826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const CvMat* _missing, bool preprocessed_input ) const
28836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
28846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* result = 0;
28856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int* catbuf = 0;
28866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTree::predict" );
28886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
28906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
28916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, step, mstep = 0;
28926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float* sample;
28936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const uchar* m = 0;
28946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* node = root;
28956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* vtype;
28966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* vidx;
28976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* cmap;
28986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int* cofs;
28996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !node )
29016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsError, "The tree has not been trained yet" );
29026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
29046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _sample->cols != 1 && _sample->rows != 1 ||
29056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
29066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
29076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg,
29086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "the input sample must be 1d floating-point vector with the same "
29096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "number of elements as the total number of variables used for training" );
29106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    sample = _sample->data.fl;
29126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
29136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->cat_count && !preprocessed_input ) // cache for categorical variables
29156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
29166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n = data->cat_count->cols;
29176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
29186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
29196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            catbuf[i] = -1;
29206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
29216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _missing )
29236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
29246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
29256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        !CV_ARE_SIZES_EQ(_missing, _sample) )
29266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg,
29276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "the missing data mask must be 8-bit vector of the same size as input sample" );
29286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        m = _missing->data.ptr;
29296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
29306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
29316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    vtype = data->var_type->data.i;
29336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
29346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cmap = data->cat_map ? data->cat_map->data.i : 0;
29356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
29366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    while( node->Tn > pruned_tree_idx && node->left )
29386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
29396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeSplit* split = node->split;
29406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int dir = 0;
29416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( ; !dir && split != 0; split = split->next )
29426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
29436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int vi = split->var_idx;
29446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int ci = vtype[vi];
29456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            i = vidx ? vidx[vi] : vi;
29466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            float val = sample[i*step];
29476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( m && m[i*mstep] )
29486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                continue;
29496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( ci < 0 ) // ordered
29506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir = val <= split->ord.c ? -1 : 1;
29516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else // categorical
29526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
29536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int c;
29546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( preprocessed_input )
29556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    c = cvRound(val);
29566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                else
29576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
29586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    c = catbuf[ci];
29596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( c < 0 )
29606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
29616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        int a = c = cofs[ci];
29626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        int b = cofs[ci+1];
29636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        int ival = cvRound(val);
29646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        if( ival != val )
29656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            CV_ERROR( CV_StsBadArg,
29666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            "one of input categorical variable is not an integer" );
29676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        while( a < b )
29696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        {
29706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            c = (a + b) >> 1;
29716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            if( ival < cmap[c] )
29726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                b = c;
29736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            else if( ival > cmap[c] )
29746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                a = c+1;
29756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            else
29766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                break;
29776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        }
29786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        if( c < 0 || ival != cmap[c] )
29806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            continue;
29816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        catbuf[ci] = c -= cofs[ci];
29836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
29846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
29856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir = CV_DTREE_CAT_DIR(c, split->subset);
29866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
29876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( split->inversed )
29896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dir = -dir;
29906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
29916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
29926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !dir )
29936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
29946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double diff = node->right->sample_count - node->left->sample_count;
29956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dir = diff < 0 ? -1 : 1;
29966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
29976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node = dir < 0 ? node->left : node->right;
29986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
29996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    result = node;
30016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
30036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
30056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
30066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvMat* CvDTree::get_var_importance()
30096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
30106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !var_importance )
30116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
30126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeNode* node = root;
30136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double* importance;
30146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !node )
30156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            return 0;
30166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        var_importance = cvCreateMat( 1, data->var_count, CV_64F );
30176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero( var_importance );
30186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        importance = var_importance->data.db;
30196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for(;;)
30216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
30226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvDTreeNode* parent;
30236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ;; node = node->left )
30246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
30256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CvDTreeSplit* split = node->split;
30266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( !node->left || node->Tn <= pruned_tree_idx )
30286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    break;
30296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( ; split != 0; split = split->next )
30316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    importance[split->var_idx] += split->quality;
30326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
30336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( parent = node->parent; parent && parent->right == node;
30356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                node = parent, parent = parent->parent )
30366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                ;
30376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !parent )
30396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                break;
30406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = parent->right;
30426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
30436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
30456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
30466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return var_importance;
30486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
30496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
30526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
30536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int ci;
30546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
30566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "var", split->var_idx );
30576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "quality", split->quality );
30586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ci = data->get_var_type(split->var_idx);
30606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( ci >= 0 ) // split on a categorical var
30616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
30626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
30636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
30646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
30656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // ad-hoc rule when to use inverse categorical split notation
30676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // to achieve more compact and clear representation
30686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
30696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
30716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
30726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < n; i++ )
30746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
30756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int dir = CV_DTREE_CAT_DIR(i,split->subset);
30766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( dir*default_dir < 0 )
30776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvWriteInt( fs, 0, i );
30786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
30796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvEndWriteStruct( fs );
30806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
30816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
30826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
30836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
30856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
30866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
30896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
30906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split;
30916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, 0, CV_NODE_MAP );
30936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "depth", node->depth );
30956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "sample_count", node->sample_count );
30966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "value", node->value );
30976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
30986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->is_classifier )
30996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWriteInt( fs, "norm_class_idx", node->class_idx );
31006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "Tn", node->Tn );
31026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "complexity", node->complexity );
31036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "alpha", node->alpha );
31046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "node_risk", node->node_risk );
31056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "tree_risk", node->tree_risk );
31066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "tree_error", node->tree_error );
31076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( node->left )
31096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
31106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
31116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( split = node->split; split != 0; split = split->next )
31136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            write_split( fs, split );
31146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvEndWriteStruct( fs );
31166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
31176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
31196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
31206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write_tree_nodes( CvFileStorage* fs )
31236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
31246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
31256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
31276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* node = root;
31296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // traverse the tree and save all the nodes in depth-first order
31316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for(;;)
31326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
31336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeNode* parent;
31346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for(;;)
31356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
31366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            write_node( fs, node );
31376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !node->left )
31386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                break;
31396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = node->left;
31406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
31416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( parent = node->parent; parent && parent->right == node;
31436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            node = parent, parent = parent->parent )
31446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            ;
31456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !parent )
31476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
31486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node = parent->right;
31506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
31516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
31536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
31546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write( CvFileStorage* fs, const char* name )
31576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
31586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //CV_FUNCNAME( "CvDTree::write" );
31596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
31616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
31636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    get_var_importance();
31656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->write_params( fs );
31666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( var_importance )
31676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWrite( fs, "var_importance", var_importance );
31686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    write( fs );
31696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
31716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
31736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
31746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::write( CvFileStorage* fs )
31776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
31786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //CV_FUNCNAME( "CvDTree::write" );
31796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
31816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
31836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
31856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    write_tree_nodes( fs );
31866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs );
31876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
31896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
31906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
31936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
31946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit* split = 0;
31956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTree::read_split" );
31976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
31986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
31996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int vi, ci;
32016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
32036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
32046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    vi = cvReadIntByName( fs, fnode, "var", -1 );
32066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( (unsigned)vi >= (unsigned)data->var_count )
32076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
32086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ci = data->get_var_type(vi);
32106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( ci >= 0 ) // split on categorical var
32116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
32126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i, n = data->cat_count->data.i[ci], inversed = 0, val;
32136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvSeqReader reader;
32146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvFileNode* inseq;
32156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split = data->new_split_cat( vi, 0 );
32166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        inseq = cvGetFileNodeByName( fs, fnode, "in" );
32176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !inseq )
32186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
32196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
32206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            inversed = 1;
32216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
32226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !inseq ||
32236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
32246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsParseError,
32256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "Either 'in' or 'not_in' tags should be inside a categorical split data" );
32266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
32286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
32296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            val = inseq->data.i;
32306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( (unsigned)val >= (unsigned)n )
32316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
32326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            split->subset[val >> 5] |= 1 << (val & 31);
32346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
32356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
32366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
32376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvStartReadSeq( inseq->data.seq, &reader );
32386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < reader.seq->total; i++ )
32406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
32416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CvFileNode* inode = (CvFileNode*)reader.ptr;
32426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                val = inode->data.i;
32436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
32446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
32456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split->subset[val >> 5] |= 1 << (val & 31);
32476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
32486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
32496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
32506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // for categorical splits we do not use inversed splits,
32526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // instead we inverse the variable set in the split
32536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( inversed )
32546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < (n + 31) >> 5; i++ )
32556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split->subset[i] ^= -1;
32566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
32576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
32586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
32596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvFileNode* cmp_node;
32606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split = data->new_split_ord( vi, 0, 0, 0, 0 );
32616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
32636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !cmp_node )
32646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
32656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
32666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            split->inversed = 1;
32676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
32686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        split->ord.c = (float)cvReadReal( cmp_node );
32706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
32716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
32736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
32756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return split;
32776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
32786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
32816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
32826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* node = 0;
32836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTree::read_node" );
32856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
32876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvFileNode* splits;
32896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, depth;
32906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
32926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
32936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
32956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    depth = cvReadIntByName( fs, fnode, "depth", -1 );
32966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( depth != node->depth )
32976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "incorrect node depth" );
32986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
32996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
33006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->value = cvReadRealByName( fs, fnode, "value" );
33016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->is_classifier )
33026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
33036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->Tn = cvReadIntByName( fs, fnode, "Tn" );
33056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->complexity = cvReadIntByName( fs, fnode, "complexity" );
33066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->alpha = cvReadRealByName( fs, fnode, "alpha" );
33076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
33086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
33096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
33106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    splits = cvGetFileNodeByName( fs, fnode, "splits" );
33126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( splits )
33136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
33146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvSeqReader reader;
33156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeSplit* last_split = 0;
33166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
33186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
33196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvStartReadSeq( splits->data.seq, &reader );
33216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < reader.seq->total; i++ )
33226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
33236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvDTreeSplit* split;
33246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
33256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !last_split )
33266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                node->split = last_split = split;
33276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
33286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                last_split = last_split->next = split;
33296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
33316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
33326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
33336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
33356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return node;
33376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
33386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
33416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
33426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTree::read_tree_nodes" );
33436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
33456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvSeqReader reader;
33476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode _root;
33486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeNode* parent = &_root;
33496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i;
33506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    parent->left = parent->right = parent->parent = 0;
33516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartReadSeq( fnode->data.seq, &reader );
33536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < reader.seq->total; i++ )
33556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
33566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvDTreeNode* node;
33576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
33596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !parent->left )
33606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            parent->left = node;
33616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
33626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            parent->right = node;
33636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( node->split )
33646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            parent = node;
33656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
33666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
33676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            while( parent && parent->right )
33686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                parent = parent->parent;
33696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
33706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
33726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
33736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    root = _root.left;
33756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
33776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
33786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
33816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
33826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeTrainData* _data = new CvDTreeTrainData();
33836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    _data->read_params( fs, fnode );
33846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    read( fs, fnode, _data );
33866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    get_var_importance();
33876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
33886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// a special entry point for reading weak decision trees from the tree ensembles
33916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
33926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
33936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvDTree::read" );
33946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
33966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvFileNode* tree_nodes;
33986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
33996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
34006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = _data;
34016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
34026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
34036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
34046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "nodes tag is missing" );
34056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
34066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
34076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    read_tree_nodes( fs, tree_nodes );
34086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
34096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
34106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
34116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
34126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* End of file. */
3413