16acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*M///////////////////////////////////////////////////////////////////////////////////////
26acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
36acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
46acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
56acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  By downloading, copying, installing or using the software you agree to this license.
66acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  If you do not agree to this license, do not download, install,
76acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  copy or use the software.
86acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
96acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//                        Intel License Agreement
116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Copyright (C) 2000, Intel Corporation, all rights reserved.
136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Third party copyrights are property of their respective owners.
146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Redistribution and use in source and binary forms, with or without modification,
166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are permitted provided that the following conditions are met:
176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's of source code must retain the above copyright notice,
196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer.
206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's in binary form must reproduce the above copyright notice,
226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer in the documentation
236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     and/or other materials provided with the distribution.
246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * The name of Intel Corporation may not be used to endorse or promote products
266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     derived from this software without specific prior written permission.
276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// This software is provided by the copyright holders and contributors "as is" and
296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// any express or implied warranties, including, but not limited to, the implied
306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// warranties of merchantability and fitness for a particular purpose are disclaimed.
316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct,
326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages
336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// (including, but not limited to, procurement of substitute goods or services;
346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused
356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability,
366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort (including negligence or otherwise) arising in any way out of
376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even if advised of the possibility of such damage.
386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/
406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h"
426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvForestTree::CvForestTree()
446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    forest = NULL;
466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvForestTree::~CvForestTree()
506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvForestTree::train( CvDTreeTrainData* _data,
566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                          const CvMat* _subsample_idx,
576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                          CvRTrees* _forest )
586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool result = false;
606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvForestTree::train" );
626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    forest = _forest;
686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = _data;
706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->shared = true;
716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(result = do_train(_subsample_idx));
726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool
806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvForestTree::train( const CvMat*, int, const CvMat*, const CvMat*,
816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert(0);
846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return false;
856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool
896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvForestTree::train( CvDTreeTrainData*, const CvMat* )
906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert(0);
926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return false;
936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvDTreeSplit* CvForestTree::find_best_split( CvDTreeNode* node )
976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int vi;
996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeSplit *best_split = 0, *split = 0, *t;
1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("CvForestTree::find_best_split");
1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* active_var_mask = 0;
1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( forest )
1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int var_count;
1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvRNG* rng = forest->get_rng();
1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        active_var_mask = forest->get_active_var_mask();
1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        var_count = active_var_mask->cols;
1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ASSERT( var_count == data->var_count );
1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( vi = 0; vi < var_count; vi++ )
1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            uchar temp;
1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int i1 = cvRandInt(rng) % var_count;
1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int i2 = cvRandInt(rng) % var_count;
1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_SWAP( active_var_mask->data.ptr[i1],
1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                active_var_mask->data.ptr[i2], temp );
1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( vi = 0; vi < data->var_count; vi++ )
1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int ci = data->var_type->data.i[vi];
1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( node->num_valid[vi] <= 1
1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            || (active_var_mask && !active_var_mask->data.ptr[vi]) )
1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            continue;
1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( data->is_classifier )
1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( ci >= 0 )
1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_split_cat_class( node, vi );
1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_split_ord_class( node, vi );
1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( ci >= 0 )
1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_split_cat_reg( node, vi );
1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                split = find_split_ord_reg( node, vi );
1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( split )
1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !best_split || best_split->quality < split->quality )
1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_SWAP( best_split, split, t );
1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( split )
1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvSetRemoveByPtr( data->split_heap, split );
1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return best_split;
1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData* _data )
1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTree::read( fs, fnode, _data );
1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    forest = _forest;
1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvForestTree::read( CvFileStorage*, CvFileNode* )
1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    assert(0);
1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvForestTree::read( CvFileStorage* _fs, CvFileNode* _node,
1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                         CvDTreeTrainData* _data )
1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTree::read( _fs, _node, _data );
1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//////////////////////////////////////////////////////////////////////////////////////////
1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//                                  Random trees                                        //
1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//////////////////////////////////////////////////////////////////////////////////////////
1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvRTrees::CvRTrees()
1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nclasses         = 0;
1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    oob_error        = 0;
1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ntrees           = 0;
1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    trees            = NULL;
1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data             = NULL;
1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    active_var_mask  = NULL;
1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_importance   = NULL;
1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rng = cvRNG(0xffffffff);
1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default_model_name = "my_random_trees";
1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvRTrees::clear()
1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int k;
2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k < ntrees; k++ )
2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        delete trees[k];
2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &trees );
2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    delete data;
2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = 0;
2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &active_var_mask );
2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &var_importance );
2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ntrees = 0;
2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvRTrees::~CvRTrees()
2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvMat* CvRTrees::get_active_var_mask()
2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return active_var_mask;
2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvRNG* CvRTrees::get_rng()
2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return &rng;
2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvRTrees::train( const CvMat* _train_data, int _tflag,
2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        const CvMat* _responses, const CvMat* _var_idx,
2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        const CvMat* _sample_idx, const CvMat* _var_type,
2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        const CvMat* _missing_mask, CvRTParams params )
2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool result = false;
2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("CvRTrees::train");
2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int var_count = 0;
2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.regression_accuracy, params.use_surrogates, params.max_categories,
2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.cv_folds, params.use_1se_rule, false, params.priors );
2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = new CvDTreeTrainData();
2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _sample_idx, _var_type, _missing_mask, tree_params, true));
2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_count = data->var_count;
2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.nactive_vars > var_count )
2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.nactive_vars = var_count;
2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( params.nactive_vars == 0 )
2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.nactive_vars = (int)sqrt((double)var_count);
2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( params.nactive_vars < 0 )
2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.term_crit = cvCheckTermCriteria( params.term_crit, 0.1, 1000 );
2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // Create mask of active variables at the tree nodes
2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.calc_var_importance )
2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(var_importance  = cvCreateMat( 1, var_count, CV_32FC1 ));
2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero(var_importance);
2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    { // initialize active variables mask
2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat submask1, submask2;
2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSet( &submask1, cvScalar(1) );
2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero( &submask2 );
2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(result = grow_forest( params.term_crit ));
2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    result = true;
2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvRTrees::grow_forest( const CvTermCriteria term_crit )
2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool result = false;
2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* sample_idx_mask_for_tree = 0;
2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* sample_idx_for_tree      = 0;
2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* oob_sample_votes	   = 0;
2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* oob_responses       = 0;
2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* oob_samples_perm_ptr= 0;
2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* samples_ptr     = 0;
3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    uchar* missing_ptr     = 0;
3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* true_resp_ptr   = 0;
3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("CvRTrees::grow_forest");
3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int max_ntrees = term_crit.max_iter;
3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const double max_oob_err = term_crit.epsilon;
3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int dims = data->var_count;
3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float maximal_response = 0;
3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // oob_predictions_sum[i] = sum of predicted values for the i-th sample
3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // oob_num_of_predictions[i] = number of summands
3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //                            (number of predictions for the i-th sample)
3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // initialize these variable to avoid warning C4701
3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nsamples = data->sample_count;
3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nclasses = data->get_num_classes();
3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    memset( trees, 0, sizeof(trees[0])*max_ntrees );
3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( data->is_classifier )
3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero(oob_sample_votes);
3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // oob_responses[0,i] = oob_predictions_sum[i]
3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    = sum of predicted values for the i-th sample
3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // oob_responses[1,i] = oob_num_of_predictions[i]
3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        //    = number of summands (number of predictions for the i-th sample)
3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero(oob_responses);
3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGetRow( oob_responses, &oob_predictions_sum, 0 );
3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 ));
3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(sample_idx_for_tree      = cvCreateMat( 1, nsamples, CV_32SC1 ));
3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples ));
3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double minval, maxval;
3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvMinMaxLoc( &responses, &minval, &maxval );
3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ntrees = 0;
3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    while( ntrees < max_ntrees )
3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int i, oob_samples_count = 0;
3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double ncorrect_responses = 0; // used for estimation of variable importance
3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat sample, missing;
3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvForestTree* tree = 0;
3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero( sample_idx_mask_for_tree );
3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < nsamples; i++ ) //form sample for creation one tree
3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int idx = cvRandInt( &rng ) % nsamples;
3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sample_idx_for_tree->data.i[i] = idx;
3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sample_idx_mask_for_tree->data.ptr[idx] = 0xFF;
3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        trees[ntrees] = new CvForestTree();
3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        tree = trees[ntrees];
3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(tree->train( data, sample_idx_for_tree, this ));
3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // form array of OOB samples indices and get these samples
3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        oob_error = 0;
3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < nsamples; i++,
3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            sample.data.fl += dims, missing.data.ptr += dims )
3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvDTreeNode* predicted_node = 0;
3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // check if the sample is OOB
3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( sample_idx_mask_for_tree->data.ptr[i] )
3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                continue;
3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // predict oob samples
3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !predicted_node )
3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !data->is_classifier ) //regression
3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double avg_resp, resp = predicted_node->value;
3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                oob_predictions_sum.data.fl[i] += (float)resp;
3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                oob_num_of_predictions.data.fl[i] += 1;
3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                // compute oob error
4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                avg_resp -= true_resp_ptr[i];
4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                oob_error += avg_resp*avg_resp;
4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                resp = (resp - true_resp_ptr[i])/maximal_response;
4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                ncorrect_responses += exp( -resp*resp );
4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else //classification
4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double prdct_resp;
4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CvPoint max_loc;
4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CvMat votes;
4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvGetRow(oob_sample_votes, &votes, i);
4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                votes.data.i[predicted_node->class_idx]++;
4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                // compute oob error
4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                prdct_resp = data->cat_map->data.i[max_loc.x];
4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            oob_samples_count++;
4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( oob_samples_count > 0 )
4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            oob_error /= (double)oob_samples_count;
4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // estimate variable importance
4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( var_importance && oob_samples_count > 0 )
4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int m;
4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( m = 0; m < dims; m++ )
4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double ncorrect_responses_permuted = 0;
4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                // randomly permute values of the m-th variable in the oob samples
4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                float* mth_var_ptr = oob_samples_perm_ptr + m;
4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < nsamples; i++ )
4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    int i1, i2;
4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    float temp;
4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        continue;
4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    i1 = cvRandInt( &rng ) % nsamples;
4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    i2 = cvRandInt( &rng ) % nsamples;
4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    // turn values of (m-1)-th variable, that were permuted
4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    // at the previous iteration, untouched
4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( m > 1 )
4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                // predict "permuted" cases and calculate the number of votes for the
4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                // correct class in the variable-m-permuted oob data
4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < nsamples; i++,
4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    sample.data.fl += dims, missing.data.ptr += dims )
4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double predct_resp, true_resp;
4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        continue;
4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    predct_resp = tree->predict(&sample, &missing, true)->value;
4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    true_resp   = true_resp_ptr[i];
4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( data->is_classifier )
4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    else
4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        true_resp = (true_resp - predct_resp)/maximal_response;
4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        ncorrect_responses_permuted += exp( -true_resp*true_resp );
4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                var_importance->data.fl[m] += (float)(ncorrect_responses
4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    - ncorrect_responses_permuted);
4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        ntrees++;
4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( var_importance )
4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(cvConvertScale( var_importance, var_importance, 1./ntrees/nsamples ));
4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    result = true;
4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &sample_idx_mask_for_tree );
4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &sample_idx_for_tree );
4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &oob_sample_votes );
4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &oob_responses );
4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &oob_samples_perm_ptr );
5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &samples_ptr );
5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &missing_ptr );
5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &true_resp_ptr );
5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvMat* CvRTrees::get_var_importance()
5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return var_importance;
5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat CvRTrees::get_proximity( const CvMat* sample1, const CvMat* sample2,
5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                              const CvMat* missing1, const CvMat* missing2 ) const
5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float result = 0;
5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvRTrees::get_proximity" );
5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i;
5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < ntrees; i++ )
5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        result += trees[i]->predict( sample1, missing1 ) ==
5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        trees[i]->predict( sample2, missing2 ) ?  1 : 0;
5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    result = result/(float)ntrees;
5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat CvRTrees::predict( const CvMat* sample, const CvMat* missing ) const
5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double result = -1;
5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("CvRTrees::predict");
5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int k;
5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( nclasses > 0 ) //classification
5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int max_nvotes = 0;
5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int* votes = (int*)alloca( sizeof(int)*nclasses );
5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        memset( votes, 0, sizeof(*votes)*nclasses );
5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < ntrees; k++ )
5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int nvotes;
5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int class_idx = predicted_node->class_idx;
5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ASSERT( 0 <= class_idx && class_idx < nclasses );
5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            nvotes = ++votes[class_idx];
5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( nvotes > max_nvotes )
5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                max_nvotes = nvotes;
5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                result = predicted_node->value;
5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else // regression
5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        result = 0;
5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < ntrees; k++ )
5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            result += trees[k]->predict( sample, missing )->value;
5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        result /= (double)ntrees;
5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (float)result;
5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvRTrees::write( CvFileStorage* fs, const char* name )
5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvRTrees::write" );
5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int k;
5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( ntrees < 1 || !trees || nsamples < 1 )
5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Invalid CvRTrees object" );
5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES );
5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "nclasses", nclasses );
5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "nsamples", nsamples );
5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "nactive_vars", (int)cvSum(active_var_mask).val[0] );
5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteReal( fs, "oob_error", oob_error );
5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( var_importance )
5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvWrite( fs, "var_importance", var_importance );
5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvWriteInt( fs, "ntrees", ntrees );
6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(data->write_params( fs ));
6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k < ntrees; k++ )
6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( trees[k]->write( fs ));
6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvEndWriteStruct( fs );
6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs ); //trees
6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvEndWriteStruct( fs ); //CV_TYPE_NAME_ML_RTREES
6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvRTrees::read" );
6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int nactive_vars, var_count, k;
6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvSeqReader reader;
6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvFileNode* trees_fnode = 0;
6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nclasses     = cvReadIntByName( fs, fnode, "nclasses", -1 );
6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nsamples     = cvReadIntByName( fs, fnode, "nsamples" );
6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nactive_vars = cvReadIntByName( fs, fnode, "nactive_vars", -1 );
6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    oob_error    = cvReadRealByName(fs, fnode, "oob_error", -1 );
6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ntrees       = cvReadIntByName( fs, fnode, "ntrees", -1 );
6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_importance = (CvMat*)cvReadByName( fs, fnode, "var_importance" );
6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( nclasses < 0 || nsamples <= 0 || nactive_vars < 0 || oob_error < 0 || ntrees <= 0)
6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "Some <nclasses>, <nsamples>, <var_count>, "
6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "<nactive_vars>, <oob_error>, <ntrees> of tags are missing" );
6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    rng = CvRNG( -1 );
6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*ntrees );
6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    memset( trees, 0, sizeof(trees[0])*ntrees );
6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data = new CvDTreeTrainData();
6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->read_params( fs, fnode );
6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    data->shared = true;
6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    trees_fnode = cvGetFileNodeByName( fs, fnode, "trees" );
6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvStartReadSeq( trees_fnode->data.seq, &reader );
6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( reader.seq->total != ntrees )
6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsParseError,
6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "<ntrees> is not equal to the number of trees saved in file" );
6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k < ntrees; k++ )
6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        trees[k] = new CvForestTree();
6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL(trees[k]->read( fs, (CvFileNode*)reader.ptr, this, data ));
6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_count = data->var_count;
6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // initialize active variables mask
6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat submask1, submask2;
6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGetCols( active_var_mask, &submask1, 0, nactive_vars );
6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGetCols( active_var_mask, &submask2, nactive_vars, var_count );
6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSet( &submask1, cvScalar(1) );
6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero( &submask2 );
6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvRTrees::get_tree_count() const
6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return ntrees;
6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvForestTree* CvRTrees::get_tree(int i) const
6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (unsigned)i < (unsigned)ntrees ? trees[i] : 0;
6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// End of file.
694