/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <ctype.h>

namespace cv {
namespace ml {

using std::vector;

TreeParams::TreeParams()
{
    maxDepth = INT_MAX;
    minSampleCount = 10;
    regressionAccuracy = 0.01f;
    useSurrogates = false;
    maxCategories = 10;
    CVFolds = 10;
    use1SERule = true;
    truncatePrunedTree = true;
    priors = Mat();
}

TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
                       double _regressionAccuracy, bool _useSurrogates,
                       int _maxCategories, int _CVFolds,
                       bool _use1SERule, bool _truncatePrunedTree,
                       const Mat& _priors)
{
    maxDepth = _maxDepth;
    minSampleCount = _minSampleCount;
    regressionAccuracy = (float)_regressionAccuracy;
    useSurrogates = _useSurrogates;
    maxCategories = _maxCategories;
    CVFolds = _CVFolds;
    use1SERule = _use1SERule;
    truncatePrunedTree = _truncatePrunedTree;
    priors = _priors;
}
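
// Example (illustrative sketch, not part of the library): constructing
// TreeParams explicitly, which is effectively what the public DTrees setters
// do one field at a time. The concrete values are assumptions chosen for
// demonstration only:
//
//   TreeParams p(/*maxDepth=*/8, /*minSampleCount=*/10,
//                /*regressionAccuracy=*/0.01, /*useSurrogates=*/false,
//                /*maxCategories=*/10, /*CVFolds=*/0,
//                /*use1SERule=*/true, /*truncatePrunedTree=*/true, Mat());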

DTrees::Node::Node()
{
    classIdx = 0;
    value = 0;
    parent = left = right = split = defaultDir = -1;
}

DTrees::Split::Split()
{
    varIdx = 0;
    inversed = false;
    quality = 0.f;
    next = -1;
    c = 0.f;
    subsetOfs = 0;
}


DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
    data = _data;
    vector<int> subsampleIdx;
    Mat sidx0 = _data->getTrainSampleIdx();
    if( !sidx0.empty() )
    {
        sidx0.copyTo(sidx);
        std::sort(sidx.begin(), sidx.end());
    }
    else
    {
        int n = _data->getNSamples();
        setRangeVector(sidx, n);
    }

    maxSubsetSize = 0;
}

DTreesImpl::DTreesImpl() {}
DTreesImpl::~DTreesImpl() {}
void DTreesImpl::clear()
{
    varIdx.clear();
    compVarIdx.clear();
    varType.clear();
    catOfs.clear();
    catMap.clear();
    roots.clear();
    nodes.clear();
    splits.clear();
    subsets.clear();
    classLabels.clear();

    w.release();
    _isClassifier = false;
}

void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
    clear();
    w = makePtr<WorkData>(data);

    Mat vtype = data->getVarType();
    vtype.copyTo(varType);

    data->getCatOfs().copyTo(catOfs);
    data->getCatMap().copyTo(catMap);
    data->getDefaultSubstValues().copyTo(missingSubst);

    int nallvars = data->getNAllVars();

    Mat vidx0 = data->getVarIdx();
    if( !vidx0.empty() )
        vidx0.copyTo(varIdx);
    else
        setRangeVector(varIdx, nallvars);

    initCompVarIdx();

    w->maxSubsetSize = 0;

    int i, nvars = (int)varIdx.size();
    for( i = 0; i < nvars; i++ )
        w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));

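    // Note (added for clarity): categorical subsets are stored as bit masks
    // packed into 32-bit ints, so a variable with K categories needs
    // (K + 31)/32 ints; e.g. K = 40 categories -> 2 ints.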
    w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);

    data->getSampleWeights().copyTo(w->sample_weights);

    _isClassifier = data->getResponseType() == VAR_CATEGORICAL;

    if( _isClassifier )
    {
        data->getNormCatResponses().copyTo(w->cat_responses);
        data->getClassLabels().copyTo(classLabels);
        int nclasses = (int)classLabels.size();

        Mat class_weights = params.priors;
        if( !class_weights.empty() )
        {
            if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
            {
                Mat temp;
                class_weights.convertTo(temp, CV_64F);
                class_weights = temp;
            }
            CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );

            int nsamples = (int)w->cat_responses.size();
            const double* cw = class_weights.ptr<double>();
            CV_Assert( (int)w->sample_weights.size() == nsamples );

            for( i = 0; i < nsamples; i++ )
            {
                int ci = w->cat_responses[i];
                CV_Assert( 0 <= ci && ci < nclasses );
                w->sample_weights[i] *= cw[ci];
            }
        }
    }
    else
        data->getResponses().copyTo(w->ord_responses);
}


void DTreesImpl::initCompVarIdx()
{
    int nallvars = (int)varType.size();
    compVarIdx.assign(nallvars, -1);
    int i, nvars = (int)varIdx.size(), prevIdx = -1;
    for( i = 0; i < nvars; i++ )
    {
        int vi = varIdx[i];
        CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
        prevIdx = vi;
        compVarIdx[vi] = i;
    }
}

void DTreesImpl::endTraining()
{
    w.release();
}

bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
{
    startTraining(trainData, flags);
    bool ok = addTree( w->sidx ) >= 0;
    w.release();
    endTraining();
    return ok;
}
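
// Example (illustrative sketch, not part of the library): training a single
// tree through the public cv::ml::DTrees interface, which ends up in train()
// above. The data and parameter values are assumptions for demonstration:
//
//   Mat samples = (Mat_<float>(4, 2) << 0.f, 0.f,  0.f, 1.f,
//                                       1.f, 0.f,  1.f, 1.f);
//   Mat responses = (Mat_<int>(4, 1) << 0, 0, 1, 1);
//   Ptr<DTrees> dtree = DTrees::create();
//   dtree->setMaxDepth(4);
//   dtree->setMinSampleCount(1);
//   dtree->setCVFolds(0);   // disable built-in cross-validation pruning
//   dtree->train(TrainData::create(samples, ROW_SAMPLE, responses));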

const vector<int>& DTreesImpl::getActiveVars()
{
    return varIdx;
}

int DTreesImpl::addTree(const vector<int>& sidx )
{
    size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size();

    w->wnodes.reserve(n);
    w->wsplits.reserve(n);
    w->wsubsets.reserve(n*w->maxSubsetSize);
    w->wnodes.clear();
    w->wsplits.clear();
    w->wsubsets.clear();

    int cv_n = params.getCVFolds();

    if( cv_n > 0 )
    {
        w->cv_Tn.resize(n*cv_n);
        w->cv_node_error.resize(n*cv_n);
        w->cv_node_risk.resize(n*cv_n);
    }

    // build the tree recursively
    int w_root = addNodeAndTrySplit(-1, sidx);
    int maxdepth = INT_MAX; // pruneCV(root);

    int w_nidx = w_root, pidx = -1, depth = 0;
    int root = (int)nodes.size();

    for(;;)
    {
        const WNode& wnode = w->wnodes[w_nidx];
        Node node;
        node.parent = pidx;
        node.classIdx = wnode.class_idx;
        node.value = wnode.value;
        node.defaultDir = wnode.defaultDir;

        int wsplit_idx = wnode.split;
        if( wsplit_idx >= 0 )
        {
            const WSplit& wsplit = w->wsplits[wsplit_idx];
            Split split;
            split.c = wsplit.c;
            split.quality = wsplit.quality;
            split.inversed = wsplit.inversed;
            split.varIdx = wsplit.varIdx;
            split.subsetOfs = -1;
            if( wsplit.subsetOfs >= 0 )
            {
                int ssize = getSubsetSize(split.varIdx);
                split.subsetOfs = (int)subsets.size();
                subsets.resize(split.subsetOfs + ssize);
                // Guard against ssize == 0: in that case the resize() above
                // performs no real allocation, so taking the address
                // &subsets[split.subsetOfs] would be out of range. The check
                // also skips a useless zero-length memcpy call.
                if(ssize > 0)
                {
                    memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
                }
            }
            node.split = (int)splits.size();
            splits.push_back(split);
        }
        int nidx = (int)nodes.size();
        nodes.push_back(node);
        if( pidx >= 0 )
        {
            int w_pidx = w->wnodes[w_nidx].parent;
            if( w->wnodes[w_pidx].left == w_nidx )
            {
                nodes[pidx].left = nidx;
            }
            else
            {
                CV_Assert(w->wnodes[w_pidx].right == w_nidx);
                nodes[pidx].right = nidx;
            }
        }

        if( wnode.left >= 0 && depth+1 < maxdepth )
        {
            w_nidx = wnode.left;
            pidx = nidx;
            depth++;
        }
        else
        {
            int w_pidx = wnode.parent;
            while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
            {
                w_nidx = w_pidx;
                w_pidx = w->wnodes[w_pidx].parent;
                nidx = pidx;
                pidx = nodes[pidx].parent;
                depth--;
            }

            if( w_pidx < 0 )
                break;

            w_nidx = w->wnodes[w_pidx].right;
            CV_Assert( w_nidx >= 0 );
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::setDParams(const TreeParams& _params)
{
    params = _params;
}

int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
{
    w->wnodes.push_back(WNode());
    int nidx = (int)(w->wnodes.size() - 1);
    WNode& node = w->wnodes.back();

    node.parent = parent;
    node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
    int nfolds = params.getCVFolds();

    if( nfolds > 0 )
    {
        w->cv_Tn.resize((nidx+1)*nfolds);
        w->cv_node_error.resize((nidx+1)*nfolds);
        w->cv_node_risk.resize((nidx+1)*nfolds);
    }

    int i, n = node.sample_count = (int)sidx.size();
    bool can_split = true;
    vector<int> sleft, sright;

    calcValue( nidx, sidx );

    if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
        can_split = false;
    else if( _isClassifier )
    {
        const int* responses = &w->cat_responses[0];
        const int* s = &sidx[0];
        int first = responses[s[0]];
        for( i = 1; i < n; i++ )
            if( responses[s[i]] != first )
                break;
        if( i == n )
            can_split = false;
    }
    else
    {
        if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
            can_split = false;
    }

    if( can_split )
        node.split = findBestSplit( sidx );

    //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);

    if( node.split >= 0 )
    {
        node.defaultDir = calcDir( node.split, sidx, sleft, sright );
        if( params.useSurrogates )
            CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");

        int left = addNodeAndTrySplit( nidx, sleft );
        int right = addNodeAndTrySplit( nidx, sright );
        w->wnodes[nidx].left = left;
        w->wnodes[nidx].right = right;
        CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
    }

    return nidx;
}

int DTreesImpl::findBestSplit( const vector<int>& _sidx )
{
    const vector<int>& activeVars = getActiveVars();
    int splitidx = -1;
    int vi_, nv = (int)activeVars.size();
    AutoBuffer<int> buf(w->maxSubsetSize*2);
    int *subset = buf, *best_subset = subset + w->maxSubsetSize;
    WSplit split, best_split;
    best_split.quality = 0.;

    for( vi_ = 0; vi_ < nv; vi_++ )
    {
        int vi = activeVars[vi_];
        if( varType[vi] == VAR_CATEGORICAL )
        {
            if( _isClassifier )
                split = findSplitCatClass(vi, _sidx, 0, subset);
            else
                split = findSplitCatReg(vi, _sidx, 0, subset);
        }
        else
        {
            if( _isClassifier )
                split = findSplitOrdClass(vi, _sidx, 0);
            else
                split = findSplitOrdReg(vi, _sidx, 0);
        }
        if( split.quality > best_split.quality )
        {
            best_split = split;
            std::swap(subset, best_subset);
        }
    }

    if( best_split.quality > 0 )
    {
        int best_vi = best_split.varIdx;
        CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
        int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
        w->wsubsets.resize(prevsz + ssize);
        for( i = 0; i < ssize; i++ )
            w->wsubsets[prevsz + i] = best_subset[i];
        best_split.subsetOfs = prevsz;
        w->wsplits.push_back(best_split);
        splitidx = (int)(w->wsplits.size()-1);
    }

    return splitidx;
}

void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
{
    WNode* node = &w->wnodes[nidx];
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
    int m = (int)classLabels.size();

    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));

    if( cv_n > 0 )
    {
        size_t sz = w->cv_Tn.size();
        w->cv_Tn.resize(sz + cv_n);
        w->cv_node_risk.resize(sz + cv_n);
        w->cv_node_error.resize(sz + cv_n);
    }

    if( _isClassifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        double* cls_count = buf;
        double* cv_cls_count = cls_count + m;

        double max_val = -1, total_weight = 0;
        int max_k = -1;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
            }
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si]; k = w->cat_responses[si];
                cv_cls_count[j*m + k] += w->sample_weights[si];
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = classLabels[max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double val_k = cv_cls_count[j*m + k];
                double val = cls_count[k] - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
        double sum = 0, sum2 = 0, sumw = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                sum += t*wval;
                sum2 += t*t*wval;
                sumw += wval;
            }
        }
        else
        {
            double *cv_sum = buf, *cv_sum2 = cv_sum + cv_n;
            double* cv_count = (double*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                cv_sum[j] += t*wval;
                cv_sum2[j] += t*t*wval;
                cv_count[j] += wval;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
                sumw += cv_count[j];
            }

            for( j = 0; j < cv_n; j++ )
            {
                // s, s2, c: sums over the samples of fold j; si, s2i (and ci
                // below): sums over the samples of all the other folds
                double s = cv_sum[j], si = sum - s;
                double s2 = cv_sum2[j], s2i = sum2 - s2;
                double c = cv_count[j], ci = sumw - c;
                double r = si/std::max(ci, DBL_EPSILON);
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            }
        }

        node->node_risk = sum2 - (sum/sumw)*sum;
        node->value = sum/sumw;
    }
}
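
// Worked example (assumed numbers) for calcValue() above: in the
// classification branch, if the accumulated class weights are
// cls_count = {2.0, 5.0, 1.0}, the node value becomes the label of class 1
// (largest weight, 5.0) and the node risk is the total weight of the
// remaining, misclassified samples: 8.0 - 5.0 = 3.0. In the regression
// branch, value = sum/sumw (the weighted mean) and
// node_risk = sum2 - value*sum, the weighted sum of squared deviations from
// that mean.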

DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    const double epsilon = FLT_EPSILON*2;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];
    double* lcw = (double*)(uchar*)buf;
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);
    int i, best_i = -1;
    double best_val = initQuality;

    for( i = 0; i < m; i++ )
        lcw[i] = rcw[i] = 0.;

    w->data->getValues( vi, _sidx, values );

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( i = 0; i < m; i++ )
    {
        double wval = rcw[i];
        R += wval;
        rsum2 += wval*wval;
    }

    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        int si = sidx[curr];
        double wval = weights[si], w2 = wval*wval;
        L += wval; R -= wval;
        int idx = responses[si];
        double lv = lcw[idx], rv = rcw[idx];
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        if( values[curr] + epsilon < values[next] )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
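
// Note on the split quality used in findSplitOrdClass() above (a sketch of
// the algebra, not new behavior): with per-class weight vectors l and r on
// the two sides of the split, L = sum_k l_k and R = sum_k r_k, the code
// maximizes
//   val = |l|^2/L + |r|^2/R = (lsum2*R + rsum2*L)/(L*R),
// which is equivalent to minimizing the weighted Gini impurity
//   L*(1 - |l|^2/L^2) + R*(1 - |r|^2/R^2) = (L + R) - val,
// since L + R is constant for a given node.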

// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
{
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf, *c_weights = buf + n;
    bool modified = true;
    RNG r((uint64)-1);

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        double sum = 0;
        const double* v = vectors + i*m;
        labels[i] = i < k ? i : r.uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    for( i = 0; i < n; i++ )
    {
        int i1 = r.uniform(0, n);
        int i2 = r.uniform(0, n);
        std::swap( labels[i1], labels[i2] );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const double* s = csums + i*m;
            double sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const double* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}

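// Usage note (derived from the code below): findSplitCatClass() calls
// clusterCategories() only for multi-class problems (m > 2) where the
// variable has more than params.maxCategories categories; the per-category
// class histograms are then compressed into at most maxCategories clusters,
// so the exhaustive subset search stays at 2^maxCategories instead of 2^K
// subsets.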
DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
                                                  double initQuality, int* subset )
{
    int _mi = getCatCount(vi), mi = _mi;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    int base_size = m*(3 + mi) + mi + 1;
    if( m > 2 && mi > params.getMaxCategories() )
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
    else
        base_size += mi;
    AutoBuffer<double> buf(base_size + n);

    double* lc = (double*)buf;
    double* rc = lc + m;
    double* _cjk = rc + m*2, *cjk = _cjk;
    double* c_weights = cjk + m*mi;

    int* labels = (int*)(buf + base_size);
    w->data->getNormCatValues(vi, _sidx, labels);
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    int* cluster_labels = 0;
    double** dbl_ptr = 0;
    int i, j, k, si, idx;
    double L = 0, R = 0;
    double best_val = initQuality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    // init array of counters:
    // c_{jk} is the total weight of the samples that have the vi-th input
    // variable equal to j and response equal to k.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        si = _sidx[i];
        j = labels[i];
        k = responses[si];
        cjk[j*m + k] += weights[si];
    }

    if( m > 2 )
    {
        if( mi > params.getMaxCategories() )
        {
            mi = std::min(params.getMaxCategories(), n);
            cjk = c_weights + _mi;
            cluster_labels = (int*)(cjk + m*mi);
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        CV_Assert( m == 2 );
        dbl_ptr = (double**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            dbl_ptr[j] = cjk + j*2 + 1;
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
        subset_i = 0;
        subset_n = mi;
    }

    for( k = 0; k < m; k++ )
    {
        double sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        CV_Assert(sum > 0);
        rc[k] = sum;
        lc[k] = 0;
    }

    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        else
        {
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine the index of the changed bit: diff has exactly one bit
            // set (consecutive Gray codes differ in a single bit), so converting
            // it to float and extracting the IEEE-754 exponent yields the bit
            // index without a loop.
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        double* crow = cjk + idx*m;
        double weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] + t;
                double rval = rc[k] - t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] - t;
                double rval = rc[k] + t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
        if( m == 2 )
        {
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(dbl_ptr[i] - cjk) >> 1;
                subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
{
    const float epsilon = FLT_EPSILON*2;
    const double* weights = &w->sample_weights[0];
    int n = (int)_sidx.size();

    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));

    float* values = (float*)(uchar*)buf;
    int* sorted_idx = (int*)(values + n);
    w->data->getValues(vi, _sidx, values);
    const double* responses = &w->ord_responses[0];

    int i, si, best_i = -1;
    double L = 0, R = 0;
    double best_val = initQuality, lsum = 0, rsum = 0;

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        si = _sidx[i];
        R += weights[si];
        rsum += weights[si]*responses[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    // find the optimal split
    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        si = _sidx[curr];
        double wval = weights[si];
        double t = responses[si]*wval;
        L += wval; R -= wval;
        lsum += t; rsum -= t;

        if( values[curr] + epsilon < values[next] )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    int n = (int)_sidx.size();
    int mi = getCatCount(vi);

    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = (double*)buf + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);

    w->data->getNormCatValues(vi, _sidx, cat_labels);

    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
    int i, si, best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = cat_labels[i];
        si = _sidx[i];
        double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;
    }

    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());

    // revert to the unnormalized sums
    // (there should be very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        double ni = counts[idx];

        if( ni > FLT_EPSILON )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}

int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
                         vector<int>& _sleft, vector<int>& _sright )
{
    WSplit split = w->wsplits[splitidx];
    int i, si, n = (int)_sidx.size(), vi = split.varIdx;
    _sleft.reserve(n);
    _sright.reserve(n);
    _sleft.clear();
    _sright.clear();

    AutoBuffer<float> buf(n);
    int mi = getCatCount(vi);
    double wleft = 0, wright = 0;
    const double* weights = &w->sample_weights[0];

    if( mi <= 0 ) // split on an ordered variable
    {
        float c = split.c;
        float* values = buf;
        w->data->getValues(vi, _sidx, values);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            if( values[i] <= c )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    else
    {
        const int* subset = &w->wsubsets[split.subsetOfs];
        int* cat_labels = (int*)(float*)buf;
        w->data->getNormCatValues(vi, _sidx, cat_labels);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            unsigned u = cat_labels[i];
            if( CV_DTREE_CAT_DIR(u, subset) < 0 )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
    return wleft > wright ? -1 : 1;
}
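
// Note (a sketch, assuming the packed bit-mask convention used throughout
// this file): CV_DTREE_CAT_DIR(c, subset) tests bit c of the subset mask,
// roughly
//   int dir = 2*((subset[c >> 5] >> (c & 31)) & 1) - 1;   // -1 or +1
// so categories whose bit is clear are sent to the left branch and the rest
// to the right branch.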

int DTreesImpl::pruneCV( int root )
{
    vector<double> ab;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if needed, apply 1SE rule).
    // 3. store the best index and cut the branches.

    int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = params.use1SERule != 0 && _isClassifier;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // build the main tree sequence, calculate alphas
    for(;;tree_count++)
    {
        double min_alpha = updateTreeRNC(root, tree_count, -1);
        if( cutTree(root, tree_count, -1, min_alpha) )
            break;

        ab.push_back(min_alpha);
    }

    if( tree_count > 0 )
    {
        ab[0] = 0.;

        for( ti = 1; ti < tree_count-1; ti++ )
            ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
        ab[tree_count-1] = DBL_MAX*0.5;

        Mat err_jk(cv_n, tree_count, CV_64F);

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tj < tree_count; tj++ )
            {
                double min_alpha = updateTreeRNC(root, tj, j);
                if( cutTree(root, tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab[tk] > min_alpha )
                        break;
                    err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err_jk.at<double>(j, ti);
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    return min_idx;
}

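// Note on the pruning math (cost-complexity pruning, as in CART): for every
// internal node, updateTreeRNC() below computes
//   alpha = (R(node) - R(subtree)) / (leaves(subtree) - 1),
// the per-leaf increase in risk incurred by collapsing the subtree into its
// root; cutTree() then prunes every node whose alpha does not exceed the
// current threshold min_alpha.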
double DTreesImpl::updateTreeRNC( int root, double T, int fold )
{
    int nidx = root, pidx = -1, cv_n = params.getCVFolds();
    double min_alpha = DBL_MAX;

    for(;;)
    {
        WNode *node = 0, *parent = 0;

        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
                    node->tree_error = w->cv_node_error[nidx*cv_n + fold];
                }
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
        {
            node = &w->wnodes[nidx];
            parent = &w->wnodes[pidx];
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
                             - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = std::min( min_alpha, parent->alpha );
        }

        if( pidx < 0 )
            break;

        node = &w->wnodes[nidx];
        parent = &w->wnodes[pidx];
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        nidx = parent->right;
    }

    return min_alpha;
}

bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
{
    int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
    WNode* node = &w->wnodes[root];
    if( node->left < 0 )
        return true;

    for(;;)
    {
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                if( fold >= 0 )
                    w->cv_Tn[nidx*cv_n + fold] = T;
                else
                    node->Tn = T;
                if( nidx == root )
                    return true;
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
            ;

        if( pidx < 0 )
            break;

        nidx = w->wnodes[pidx].right;
    }

    return false;
}

float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
{
    CV_Assert( sample.type() == CV_32F );

    int predictType = flags & PREDICT_MASK;
    int nvars = (int)varIdx.size();
    if( nvars == 0 )
        nvars = (int)varType.size();
    int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
    int catbufsize = ncats > 0 ? nvars : 0;
    AutoBuffer<int> buf(nclasses + catbufsize + 1);
    int* votes = buf;
    int* catbuf = votes + nclasses;
    const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
    const uchar* vtype = &varType[0];
    const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
    const int* cmap = !catMap.empty() ? &catMap[0] : 0;
    const float* psample = sample.ptr<float>();
    const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
    size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
    double sum = 0.;
    int lastClassIdx = -1;
    const float MISSED_VAL = TrainData::missingValue();

    for( i = 0; i < catbufsize; i++ )
        catbuf[i] = -1;

    if( predictType == PREDICT_AUTO )
    {
        predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
            PREDICT_SUM : PREDICT_MAX_VOTE;
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        for( i = 0; i < nclasses; i++ )
            votes[i] = 0;
    }

    for( int ridx = range.start; ridx < range.end; ridx++ )
    {
        int nidx = roots[ridx], prev = nidx, c = 0;

        for(;;)
        {
            prev = nidx;
            const Node& node = nodes[nidx];
            if( node.split < 0 )
                break;
            const Split& split = splits[node.split];
            int vi = split.varIdx;
            int ci = cvidx ? cvidx[vi] : vi;
            float val = psample[ci*sstep];
            if( val == MISSED_VAL )
            {
                if( !missingSubstPtr )
                {
                    nidx = node.defaultDir < 0 ? node.left : node.right;
                    continue;
                }
                val = missingSubstPtr[vi];
            }

            if( vtype[vi] == VAR_ORDERED )
                nidx = val <= split.c ? node.left : node.right;
            else
            {
                if( flags & PREPROCESSED_INPUT )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        int a = c = cofs[vi][0];
                        int b = cofs[vi][1];

                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                                     "one of the input categorical variables is not an integer" );

                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        CV_Assert( c >= 0 && ival == cmap[c] );

                        c -= cofs[vi][0];
                        catbuf[ci] = c;
                    }
                    const int* subset = &subsets[split.subsetOfs];
                    unsigned u = c;
                    nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
                }
            }
        }

        if( predictType == PREDICT_SUM )
            sum += nodes[prev].value;
        else
        {
            lastClassIdx = nodes[prev].classIdx;
            votes[lastClassIdx]++;
        }
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        int best_idx = lastClassIdx;
        if( range.end - range.start > 1 )
        {
            best_idx = 0;
            for( i = 1; i < nclasses; i++ )
                if( votes[best_idx] < votes[i] )
                    best_idx = i;
        }
        sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
    }

    return (float)sum;
}


float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
{
    CV_Assert( !roots.empty() );
    Mat samples = _samples.getMat(), results;
    int i, nsamples = samples.rows;
    int rtype = CV_32F;
    bool needresults = _results.needed();
    float retval = 0.f;
    bool iscls = isClassifier();
    float scale = !iscls ? 1.f/(int)roots.size() : 1.f;

    if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
        rtype = CV_32S;

    if( needresults )
    {
        _results.create(nsamples, 1, rtype);
        results = _results.getMat();
    }
    else
        nsamples = std::min(nsamples, 1);

    for( i = 0; i < nsamples; i++ )
    {
        float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale;
        if( needresults )
        {
            if( rtype == CV_32F )
                results.at<float>(i) = val;
            else
                results.at<int>(i) = cvRound(val);
        }
        if( i == 0 )
            retval = val;
    }
    return retval;
}
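
// Example (illustrative sketch, continuing the training example above; the
// variable names are assumptions): classifying a single feature vector.
//
//   Mat testSample = (Mat_<float>(1, 2) << 1.f, 0.f);
//   float pred = dtree->predict(testSample);   // value stored in the leaf reached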

void DTreesImpl::writeTrainingParams(FileStorage& fs) const
{
    fs << "use_surrogates" << (params.useSurrogates ? 1 : 0);
    fs << "max_categories" << params.getMaxCategories();
    fs << "regression_accuracy" << params.getRegressionAccuracy();

    fs << "max_depth" << params.getMaxDepth();
    fs << "min_sample_count" << params.getMinSampleCount();
    fs << "cross_validation_folds" << params.getCVFolds();

    if( params.getCVFolds() > 1 )
        fs << "use_1se_rule" << (params.use1SERule ? 1 : 0);

    if( !params.priors.empty() )
        fs << "priors" << params.priors;
}

void DTreesImpl::writeParams(FileStorage& fs) const
{
    fs << "is_classifier" << isClassifier();
    fs << "var_all" << (int)varType.size();
    fs << "var_count" << getVarCount();

    int ord_var_count = 0, cat_var_count = 0;
    int i, n = (int)varType.size();
    for( i = 0; i < n; i++ )
        if( varType[i] == VAR_ORDERED )
            ord_var_count++;
        else
            cat_var_count++;
    fs << "ord_var_count" << ord_var_count;
    fs << "cat_var_count" << cat_var_count;

    fs << "training_params" << "{";
    writeTrainingParams(fs);

    fs << "}";

    if( !varIdx.empty() )
    {
        fs << "global_var_idx" << 1;
        fs << "var_idx" << varIdx;
    }

    fs << "var_type" << varType;

    if( !catOfs.empty() )
        fs << "cat_ofs" << catOfs;
    if( !catMap.empty() )
        fs << "cat_map" << catMap;
    if( !classLabels.empty() )
        fs << "class_labels" << classLabels;
    if( !missingSubst.empty() )
        fs << "missing_subst" << missingSubst;
}

void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
{
    const Split& split = splits[splitidx];

    fs << "{:";

    int vi = split.varIdx;
    fs << "var" << vi;
    fs << "quality" << split.quality;

    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
    {
        int i, n = getCatCount(vi), to_right = 0;
        const int* subset = &subsets[split.subsetOfs];
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i, subset) > 0;

        // ad-hoc rule for choosing when to use the inverse categorical split
        // notation, so as to get a more compact and readable representation
        int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;

        fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";

        for( i = 0; i < n; i++ )
        {
            int dir = CV_DTREE_CAT_DIR(i, subset);
            if( dir*default_dir < 0 )
                fs << i;
        }

        fs << "]";
    }
    else
        fs << (!split.inversed ? "le" : "gt") << split.c;

    fs << "}";
}

void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
{
    const Node& node = nodes[nidx];
    fs << "{";
    fs << "depth" << depth;
    fs << "value" << node.value;

    if( _isClassifier )
        fs << "norm_class_idx" << node.classIdx;

    if( node.split >= 0 )
    {
        fs << "splits" << "[";

        for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next )
            writeSplit( fs, splitidx );

        fs << "]";
    }

    fs << "}";
}

void DTreesImpl::writeTree( FileStorage& fs, int root ) const
{
    fs << "nodes" << "[";

    int nidx = root, pidx = 0, depth = 0;
    const Node *node = 0;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        for(;;)
        {
            writeNode( fs, nidx, depth );
            node = &nodes[nidx];
            if( node->left < 0 )
                break;
            nidx = node->left;
            depth++;
        }

        for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
             nidx = pidx, pidx = nodes[pidx].parent )
            depth--;

        if( pidx < 0 )
            break;

        nidx = nodes[pidx].right;
    }

    fs << "]";
}

void DTreesImpl::write( FileStorage& fs ) const
{
    writeParams(fs);
    writeTree(fs, roots[0]);
}
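
// Example (illustrative sketch; the file name is an assumption): persisting
// a trained model. StatModel inherits Algorithm::save(), which opens a
// FileStorage and ends up in write() above.
//
//   dtree->save("dtree.yml");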

void DTreesImpl::readParams( const FileNode& fn )
{
    _isClassifier = (int)fn["is_classifier"] != 0;
    /*int var_all = (int)fn["var_all"];
    int var_count = (int)fn["var_count"];
    int cat_var_count = (int)fn["cat_var_count"];
    int ord_var_count = (int)fn["ord_var_count"];*/

    FileNode tparams_node = fn["training_params"];

    TreeParams params0 = TreeParams();

    if( !tparams_node.empty() ) // training parameters are optional
    {
        params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
        params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
        params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
        params0.setMaxDepth((int)tparams_node["max_depth"]);
        params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
        params0.setCVFolds((int)tparams_node["cross_validation_folds"]);

        if( params0.getCVFolds() > 1 )
        {
            params.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
        }

        tparams_node["priors"] >> params0.priors;
    }

    readVectorOrMat(fn["var_idx"], varIdx);
    fn["var_type"] >> varType;

    int format = 0;
    fn["format"] >> format;
    bool isLegacy = format < 3;

    int varAll = (int)fn["var_all"];
    if (isLegacy && (int)varType.size() <= varAll)
    {
        std::vector<uchar> extendedTypes(varAll + 1, 0);

        int i = 0, n;
        if (!varIdx.empty())
        {
            n = (int)varIdx.size();
            for (; i < n; ++i)
            {
                int var = varIdx[i];
                extendedTypes[var] = varType[i];
            }
        }
        else
        {
            n = (int)varType.size();
            for (; i < n; ++i)
            {
                extendedTypes[i] = varType[i];
            }
        }
        extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
        extendedTypes.swap(varType);
    }

    readVectorOrMat(fn["cat_map"], catMap);

    if (isLegacy)
    {
        // generating "catOfs" from "cat_count"
        catOfs.clear();
        classLabels.clear();
        std::vector<int> counts;
        readVectorOrMat(fn["cat_count"], counts);
        unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
        for (; i < size; ++i)
        {
            Vec2i newOffsets(0, 0);
            if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
            {
                newOffsets[0] = curShift;
                curShift += counts[j];
                newOffsets[1] = curShift;
                ++j;
            }
            catOfs.push_back(newOffsets);
        }
        // other elements in "catMap" are "classLabels"
        if (curShift < catMap.size())
        {
            classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
            catMap.erase(catMap.begin() + curShift, catMap.end());
        }
    }
    else
    {
        fn["cat_ofs"] >> catOfs;
        fn["missing_subst"] >> missingSubst;
        fn["class_labels"] >> classLabels;
    }

    // init var mapping for node reading (var indexes or varIdx indexes)
    bool globalVarIdx = false;
    fn["global_var_idx"] >> globalVarIdx;
    if (globalVarIdx || varIdx.empty())
        setRangeVector(varMapping, (int)varType.size());
    else
        varMapping = varIdx;

    initCompVarIdx();
    setDParams(params0);
}

int DTreesImpl::readSplit( const FileNode& fn )
{
    Split split;

    int vi = (int)fn["var"];
    CV_Assert( 0 <= vi && vi <= (int)varType.size() );
    vi = varMapping[vi]; // convert to varIdx if needed
    split.varIdx = vi;

    if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
    {
        int i, val, ssize = getSubsetSize(vi);
        split.subsetOfs = (int)subsets.size();
        for( i = 0; i < ssize; i++ )
            subsets.push_back(0);
        int* subset = &subsets[split.subsetOfs];
        FileNode fns = fn["in"];
        if( fns.empty() )
        {
            fns = fn["not_in"];
            split.inversed = true;
        }

        if( fns.isInt() )
        {
            val = (int)fns;
            subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            FileNodeIterator it = fns.begin();
            int n = (int)fns.size();
            for( i = 0; i < n; i++, ++it )
            {
                val = (int)*it;
                subset[val >> 5] |= 1 << (val & 31);
            }
        }

        // for categorical splits we do not store inversed splits;
        // instead, we invert the category set stored in the split
        if( split.inversed )
        {
            for( i = 0; i < ssize; i++ )
                subset[i] ^= -1;
            split.inversed = false;
        }
    }
    else
    {
        FileNode cmpNode = fn["le"];
        if( cmpNode.empty() )
        {
            cmpNode = fn["gt"];
            split.inversed = true;
        }
        split.c = (float)cmpNode;
    }

    split.quality = (float)fn["quality"];
    splits.push_back(split);

    return (int)(splits.size() - 1);
}

int DTreesImpl::readNode( const FileNode& fn )
{
    Node node;
    node.value = (double)fn["value"];

    if( _isClassifier )
        node.classIdx = (int)fn["norm_class_idx"];

    FileNode sfn = fn["splits"];
    if( !sfn.empty() )
    {
        int i, n = (int)sfn.size(), prevsplit = -1;
        FileNodeIterator it = sfn.begin();

        for( i = 0; i < n; i++, ++it )
        {
            int splitidx = readSplit(*it);
            if( splitidx < 0 )
                break;
            if( prevsplit < 0 )
                node.split = splitidx;
            else
                splits[prevsplit].next = splitidx;
            prevsplit = splitidx;
        }
    }
    nodes.push_back(node);
    return (int)(nodes.size() - 1);
}

int DTreesImpl::readTree( const FileNode& fn )
{
    int i, n = (int)fn.size(), root = -1, pidx = -1;
    FileNodeIterator it = fn.begin();

    for( i = 0; i < n; i++, ++it )
    {
        int nidx = readNode(*it);
        if( nidx < 0 )
            break;
        Node& node = nodes[nidx];
        node.parent = pidx;
        if( pidx < 0 )
            root = nidx;
        else
        {
            Node& parent = nodes[pidx];
            if( parent.left < 0 )
                parent.left = nidx;
            else
                parent.right = nidx;
        }
        if( node.split >= 0 )
            pidx = nidx;
        else
        {
            while( pidx >= 0 && nodes[pidx].right >= 0 )
                pidx = nodes[pidx].parent;
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::read( const FileNode& fn )
{
    clear();
    readParams(fn);

    FileNode fnodes = fn["nodes"];
    CV_Assert( !fnodes.empty() );
    readTree(fnodes);
}
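
// Example (illustrative sketch; the file name is an assumption): restoring a
// model saved earlier. Algorithm::load() opens the file and ends up in
// read() above.
//
//   Ptr<DTrees> dtree2 = Algorithm::load<DTrees>("dtree.yml");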

Ptr<DTrees> DTrees::create()
{
    return makePtr<DTreesImpl>();
}

}
}

/* End of file. */
