1#include "opencv2/ml/ml.hpp"
2#include "opencv2/core/core.hpp"
3#include "opencv2/core/utility.hpp"
4#include <stdio.h>
5#include <string>
6#include <map>
7
8using namespace cv;
9using namespace cv::ml;
10
// Prints the command-line usage of this sample to stdout.
static void help()
{
    printf(
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
        "Usage:\n\t./tree_engine [-r <response_column>] [-ts type_spec] <csv filename>\n"
        "where -r <response_column> specifies the 0-based index of the response (0 by default)\n"
        "-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n");
}
20
21static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
22{
23    bool ok = model->train(data);
24    if( !ok )
25    {
26        printf("Training failed\n");
27    }
28    else
29    {
30        printf( "train error: %f\n", model->calcError(data, false, noArray()) );
31        printf( "test error: %f\n\n", model->calcError(data, true, noArray()) );
32    }
33}
34
35int main(int argc, char** argv)
36{
37    if(argc < 2)
38    {
39        help();
40        return 0;
41    }
42    const char* filename = 0;
43    int response_idx = 0;
44    std::string typespec;
45
46    for(int i = 1; i < argc; i++)
47    {
48        if(strcmp(argv[i], "-r") == 0)
49            sscanf(argv[++i], "%d", &response_idx);
50        else if(strcmp(argv[i], "-ts") == 0)
51            typespec = argv[++i];
52        else if(argv[i][0] != '-' )
53            filename = argv[i];
54        else
55        {
56            printf("Error. Invalid option %s\n", argv[i]);
57            help();
58            return -1;
59        }
60    }
61
62    printf("\nReading in %s...\n\n",filename);
63    const double train_test_split_ratio = 0.5;
64
65    Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
66
67    if( data.empty() )
68    {
69        printf("ERROR: File %s can not be read\n", filename);
70        return 0;
71    }
72
73    data->setTrainTestSplitRatio(train_test_split_ratio);
74
75    printf("======DTREE=====\n");
76    Ptr<DTrees> dtree = DTrees::create();
77    dtree->setMaxDepth(10);
78    dtree->setMinSampleCount(2);
79    dtree->setRegressionAccuracy(0);
80    dtree->setUseSurrogates(false);
81    dtree->setMaxCategories(16);
82    dtree->setCVFolds(0);
83    dtree->setUse1SERule(false);
84    dtree->setTruncatePrunedTree(false);
85    dtree->setPriors(Mat());
86    train_and_print_errs(dtree, data);
87
88    if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
89    {
90        printf("======BOOST=====\n");
91        Ptr<Boost> boost = Boost::create();
92        boost->setBoostType(Boost::GENTLE);
93        boost->setWeakCount(100);
94        boost->setWeightTrimRate(0.95);
95        boost->setMaxDepth(2);
96        boost->setUseSurrogates(false);
97        boost->setPriors(Mat());
98        train_and_print_errs(boost, data);
99    }
100
101    printf("======RTREES=====\n");
102    Ptr<RTrees> rtrees = RTrees::create();
103    rtrees->setMaxDepth(10);
104    rtrees->setMinSampleCount(2);
105    rtrees->setRegressionAccuracy(0);
106    rtrees->setUseSurrogates(false);
107    rtrees->setMaxCategories(16);
108    rtrees->setPriors(Mat());
109    rtrees->setCalculateVarImportance(false);
110    rtrees->setActiveVarCount(0);
111    rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
112    train_and_print_errs(rtrees, data);
113
114    return 0;
115}
116