1/*M///////////////////////////////////////////////////////////////////////////////////////
2//
3//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4//
5//  By downloading, copying, installing or using the software you agree to this license.
6//  If you do not agree to this license, do not download, install,
7//  copy or use the software.
8//
9//
10//                        Intel License Agreement
11//
12// Copyright (C) 2000, Intel Corporation, all rights reserved.
13// Third party copyrights are property of their respective owners.
14//
15// Redistribution and use in source and binary forms, with or without modification,
16// are permitted provided that the following conditions are met:
17//
18//   * Redistribution's of source code must retain the above copyright notice,
19//     this list of conditions and the following disclaimer.
20//
21//   * Redistribution's in binary form must reproduce the above copyright notice,
22//     this list of conditions and the following disclaimer in the documentation
23//     and/or other materials provided with the distribution.
24//
25//   * The name of Intel Corporation may not be used to endorse or promote products
26//     derived from this software without specific prior written permission.
27//
28// This software is provided by the copyright holders and contributors "as is" and
29// any express or implied warranties, including, but not limited to, the implied
30// warranties of merchantability and fitness for a particular purpose are disclaimed.
31// In no event shall the Intel Corporation or contributors be liable for any direct,
32// indirect, incidental, special, exemplary, or consequential damages
33// (including, but not limited to, procurement of substitute goods or services;
34// loss of use, data, or profits; or business interruption) however caused
35// and on any theory of liability, whether in contract, strict liability,
36// or tort (including negligence or otherwise) arising in any way out of
37// the use of this software, even if advised of the possibility of such damage.
38//
39//M*/
40
41#include "precomp.hpp"
42#include <ctype.h>
43#include <algorithm>
44#include <iterator>
45
46namespace cv { namespace ml {
47
48static const float MISSED_VAL = TrainData::missingValue();
49static const int VAR_MISSED = VAR_ORDERED;
50
51TrainData::~TrainData() {}
52
53Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
54{
55    if( idx.empty() )
56        return vec;
57    int i, j, n = idx.checkVector(1, CV_32S);
58    int type = vec.type();
59    CV_Assert( type == CV_32S || type == CV_32F || type == CV_64F );
60    int dims = 1, m;
61
62    if( vec.cols == 1 || vec.rows == 1 )
63    {
64        dims = 1;
65        m = vec.cols + vec.rows - 1;
66    }
67    else
68    {
69        dims = vec.cols;
70        m = vec.rows;
71    }
72
73    Mat subvec;
74
75    if( vec.cols == m )
76        subvec.create(dims, n, type);
77    else
78        subvec.create(n, dims, type);
79    if( type == CV_32S )
80        for( i = 0; i < n; i++ )
81        {
82            int k = idx.at<int>(i);
83            CV_Assert( 0 <= k && k < m );
84            if( dims == 1 )
85                subvec.at<int>(i) = vec.at<int>(k);
86            else
87                for( j = 0; j < dims; j++ )
88                    subvec.at<int>(i, j) = vec.at<int>(k, j);
89        }
90    else if( type == CV_32F )
91        for( i = 0; i < n; i++ )
92        {
93            int k = idx.at<int>(i);
94            CV_Assert( 0 <= k && k < m );
95            if( dims == 1 )
96                subvec.at<float>(i) = vec.at<float>(k);
97            else
98                for( j = 0; j < dims; j++ )
99                    subvec.at<float>(i, j) = vec.at<float>(k, j);
100        }
101    else
102        for( i = 0; i < n; i++ )
103        {
104            int k = idx.at<int>(i);
105            CV_Assert( 0 <= k && k < m );
106            if( dims == 1 )
107                subvec.at<double>(i) = vec.at<double>(k);
108            else
109                for( j = 0; j < dims; j++ )
110                    subvec.at<double>(i, j) = vec.at<double>(k, j);
111        }
112    return subvec;
113}
114
115class TrainDataImpl : public TrainData
116{
117public:
118    typedef std::map<String, int> MapType;
119
120    TrainDataImpl()
121    {
122        file = 0;
123        clear();
124    }
125
126    virtual ~TrainDataImpl() { closeFile(); }
127
128    int getLayout() const { return layout; }
129    int getNSamples() const
130    {
131        return !sampleIdx.empty() ? (int)sampleIdx.total() :
132               layout == ROW_SAMPLE ? samples.rows : samples.cols;
133    }
134    int getNTrainSamples() const
135    {
136        return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
137    }
138    int getNTestSamples() const
139    {
140        return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
141    }
142    int getNVars() const
143    {
144        return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
145    }
146    int getNAllVars() const
147    {
148        return layout == ROW_SAMPLE ? samples.cols : samples.rows;
149    }
150
151    Mat getSamples() const { return samples; }
152    Mat getResponses() const { return responses; }
153    Mat getMissing() const { return missing; }
154    Mat getVarIdx() const { return varIdx; }
155    Mat getVarType() const { return varType; }
156    int getResponseType() const
157    {
158        return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
159    }
160    Mat getTrainSampleIdx() const { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
161    Mat getTestSampleIdx() const { return testSampleIdx; }
162    Mat getSampleWeights() const
163    {
164        return sampleWeights;
165    }
166    Mat getTrainSampleWeights() const
167    {
168        return getSubVector(sampleWeights, getTrainSampleIdx());
169    }
170    Mat getTestSampleWeights() const
171    {
172        Mat idx = getTestSampleIdx();
173        return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);
174    }
175    Mat getTrainResponses() const
176    {
177        return getSubVector(responses, getTrainSampleIdx());
178    }
179    Mat getTrainNormCatResponses() const
180    {
181        return getSubVector(normCatResponses, getTrainSampleIdx());
182    }
183    Mat getTestResponses() const
184    {
185        Mat idx = getTestSampleIdx();
186        return idx.empty() ? Mat() : getSubVector(responses, idx);
187    }
188    Mat getTestNormCatResponses() const
189    {
190        Mat idx = getTestSampleIdx();
191        return idx.empty() ? Mat() : getSubVector(normCatResponses, idx);
192    }
193    Mat getNormCatResponses() const { return normCatResponses; }
194    Mat getClassLabels() const { return classLabels; }
195    Mat getClassCounters() const { return classCounters; }
196    int getCatCount(int vi) const
197    {
198        int n = (int)catOfs.total();
199        CV_Assert( 0 <= vi && vi < n );
200        Vec2i ofs = catOfs.at<Vec2i>(vi);
201        return ofs[1] - ofs[0];
202    }
203
204    Mat getCatOfs() const { return catOfs; }
205    Mat getCatMap() const { return catMap; }
206
207    Mat getDefaultSubstValues() const { return missingSubst; }
208
209    void closeFile() { if(file) fclose(file); file=0; }
210    void clear()
211    {
212        closeFile();
213        samples.release();
214        missing.release();
215        varType.release();
216        responses.release();
217        sampleIdx.release();
218        trainSampleIdx.release();
219        testSampleIdx.release();
220        normCatResponses.release();
221        classLabels.release();
222        classCounters.release();
223        catMap.release();
224        catOfs.release();
225        nameMap = MapType();
226        layout = ROW_SAMPLE;
227    }
228
229    typedef std::map<int, int> CatMapHash;
230
231    void setData(InputArray _samples, int _layout, InputArray _responses,
232                 InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
233                 InputArray _varType, InputArray _missing)
234    {
235        clear();
236
237        CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
238        samples = _samples.getMat();
239        layout = _layout;
240        responses = _responses.getMat();
241        varIdx = _varIdx.getMat();
242        sampleIdx = _sampleIdx.getMat();
243        sampleWeights = _sampleWeights.getMat();
244        varType = _varType.getMat();
245        missing = _missing.getMat();
246
247        int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
248        int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
249        int i, noutputvars = 0;
250
251        CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );
252
253        if( !sampleIdx.empty() )
254        {
255            CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
256                       checkRange(sampleIdx, true, 0, 0, nsamples-1)) ||
257                       sampleIdx.checkVector(1, CV_8U, true) == nsamples );
258            if( sampleIdx.type() == CV_8U )
259                sampleIdx = convertMaskToIdx(sampleIdx);
260        }
261
262        if( !sampleWeights.empty() )
263        {
264            CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
265        }
266        else
267        {
268            sampleWeights = Mat::ones(nsamples, 1, CV_32F);
269        }
270
271        if( !varIdx.empty() )
272        {
273            CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
274                       checkRange(varIdx, true, 0, 0, ninputvars)) ||
275                       varIdx.checkVector(1, CV_8U, true) == ninputvars );
276            if( varIdx.type() == CV_8U )
277                varIdx = convertMaskToIdx(varIdx);
278            varIdx = varIdx.clone();
279            std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
280        }
281
282        if( !responses.empty() )
283        {
284            CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
285            if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
286                noutputvars = 1;
287            else
288            {
289                CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
290                           (layout == COL_SAMPLE && responses.cols == nsamples) );
291                noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
292            }
293            if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
294            {
295                Mat temp;
296                transpose(responses, temp);
297                responses = temp;
298            }
299        }
300
301        int nvars = ninputvars + noutputvars;
302
303        if( !varType.empty() )
304        {
305            CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
306                       checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
307        }
308        else
309        {
310            varType.create(1, nvars, CV_8U);
311            varType = Scalar::all(VAR_ORDERED);
312            if( noutputvars == 1 )
313                varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
314        }
315
316        if( noutputvars > 1 )
317        {
318            for( i = 0; i < noutputvars; i++ )
319                CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
320        }
321
322        catOfs = Mat::zeros(1, nvars, CV_32SC2);
323        missingSubst = Mat::zeros(1, nvars, CV_32F);
324
325        vector<int> labels, counters, sortbuf, tempCatMap;
326        vector<Vec2i> tempCatOfs;
327        CatMapHash ofshash;
328
329        AutoBuffer<uchar> buf(nsamples);
330        Mat non_missing(layout == ROW_SAMPLE ? Size(1, nsamples) : Size(nsamples, 1), CV_8U, (uchar*)buf);
331        bool haveMissing = !missing.empty();
332        if( haveMissing )
333        {
334            CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
335        }
336
337        // we iterate through all the variables. For each categorical variable we build a map
338        // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
339        // often many categorical variables are similar, so we compress the map - try to re-use
340        // maps for different variables if they are identical
341        for( i = 0; i < ninputvars; i++ )
342        {
343            Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);
344
345            if( varType.at<uchar>(i) == VAR_CATEGORICAL )
346            {
347                preprocessCategorical(values_i, 0, labels, 0, sortbuf);
348                missingSubst.at<float>(i) = -1.f;
349                int j, m = (int)labels.size();
350                CV_Assert( m > 0 );
351                int a = labels.front(), b = labels.back();
352                const int* currmap = &labels[0];
353                int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
354                CatMapHash::iterator it = ofshash.find(hashval);
355                if( it != ofshash.end() )
356                {
357                    int vi = it->second;
358                    Vec2i ofs0 = tempCatOfs[vi];
359                    int m0 = ofs0[1] - ofs0[0];
360                    const int* map0 = &tempCatMap[ofs0[0]];
361                    if( m0 == m && map0[0] == a && map0[m0-1] == b )
362                    {
363                        for( j = 0; j < m; j++ )
364                            if( map0[j] != currmap[j] )
365                                break;
366                        if( j == m )
367                        {
368                            // re-use the map
369                            tempCatOfs.push_back(ofs0);
370                            continue;
371                        }
372                    }
373                }
374                else
375                    ofshash[hashval] = i;
376                Vec2i ofs;
377                ofs[0] = (int)tempCatMap.size();
378                ofs[1] = ofs[0] + m;
379                tempCatOfs.push_back(ofs);
380                std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
381            }
382            else
383            {
384                tempCatOfs.push_back(Vec2i(0, 0));
385                /*Mat missing_i = layout == ROW_SAMPLE ? missing.col(i) : missing.row(i);
386                compare(missing_i, Scalar::all(0), non_missing, CMP_EQ);
387                missingSubst.at<float>(i) = (float)(mean(values_i, non_missing)[0]);*/
388                missingSubst.at<float>(i) = 0.f;
389            }
390        }
391
392        if( !tempCatOfs.empty() )
393        {
394            Mat(tempCatOfs).copyTo(catOfs);
395            Mat(tempCatMap).copyTo(catMap);
396        }
397
398        if( varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
399        {
400            preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
401            Mat(labels).copyTo(classLabels);
402            Mat(counters).copyTo(classCounters);
403        }
404    }
405
406    Mat convertMaskToIdx(const Mat& mask)
407    {
408        int i, j, nz = countNonZero(mask), n = mask.cols + mask.rows - 1;
409        Mat idx(1, nz, CV_32S);
410        for( i = j = 0; i < n; i++ )
411            if( mask.at<uchar>(i) )
412                idx.at<int>(j++) = i;
413        return idx;
414    }
415
416    struct CmpByIdx
417    {
418        CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}
419        bool operator ()(int i, int j) const { return data[i*step] < data[j*step]; }
420        const int* data;
421        int step;
422    };
423
424    void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
425                               vector<int>* counters, vector<int>& sortbuf)
426    {
427        CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
428        int* odata = 0;
429        int ostep = 0;
430
431        if(normdata)
432        {
433            normdata->create(data.size(), CV_32S);
434            odata = normdata->ptr<int>();
435            ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
436        }
437
438        int i, n = data.cols + data.rows - 1;
439        sortbuf.resize(n*2);
440        int* idx = &sortbuf[0];
441        int* idata = (int*)data.ptr<int>();
442        int istep = data.isContinuous() ? 1 : (int)data.step1();
443
444        if( data.type() == CV_32F )
445        {
446            idata = idx + n;
447            const float* fdata = data.ptr<float>();
448            for( i = 0; i < n; i++ )
449            {
450                if( fdata[i*istep] == MISSED_VAL )
451                    idata[i] = -1;
452                else
453                {
454                    idata[i] = cvRound(fdata[i*istep]);
455                    CV_Assert( (float)idata[i] == fdata[i*istep] );
456                }
457            }
458            istep = 1;
459        }
460
461        for( i = 0; i < n; i++ )
462            idx[i] = i;
463
464        std::sort(idx, idx + n, CmpByIdx(idata, istep));
465
466        int clscount = 1;
467        for( i = 1; i < n; i++ )
468            clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
469
470        int clslabel = -1;
471        int prev = ~idata[idx[0]*istep];
472        int previdx = 0;
473
474        labels.resize(clscount);
475        if(counters)
476            counters->resize(clscount);
477
478        for( i = 0; i < n; i++ )
479        {
480            int l = idata[idx[i]*istep];
481            if( l != prev )
482            {
483                clslabel++;
484                labels[clslabel] = l;
485                int k = i - previdx;
486                if( clslabel > 0 && counters )
487                    counters->at(clslabel-1) = k;
488                prev = l;
489                previdx = i;
490            }
491            if(odata)
492                odata[idx[i]*ostep] = clslabel;
493        }
494        if(counters)
495            counters->at(clslabel) = i - previdx;
496    }
497
498    bool loadCSV(const String& filename, int headerLines,
499                 int responseStartIdx, int responseEndIdx,
500                 const String& varTypeSpec, char delimiter, char missch)
501    {
502        const int M = 1000000;
503        const char delimiters[3] = { ' ', delimiter, '\0' };
504        int nvars = 0;
505        bool varTypesSet = false;
506
507        clear();
508
509        file = fopen( filename.c_str(), "rt" );
510
511        if( !file )
512            return false;
513
514        std::vector<char> _buf(M);
515        std::vector<float> allresponses;
516        std::vector<float> rowvals;
517        std::vector<uchar> vtypes, rowtypes;
518        bool haveMissed = false;
519        char* buf = &_buf[0];
520
521        int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
522        int ninputvars = 0, noutputvars = 0;
523
524        Mat tempSamples, tempMissing, tempResponses;
525        MapType tempNameMap;
526        int catCounter = 1;
527
528        // skip header lines
529        int lineno = 0;
530        for(;;lineno++)
531        {
532            if( !fgets(buf, M, file) )
533                break;
534            if(lineno < headerLines )
535                continue;
536            // trim trailing spaces
537            int idx = (int)strlen(buf)-1;
538            while( idx >= 0 && isspace(buf[idx]) )
539                buf[idx--] = '\0';
540            // skip spaces in the beginning
541            char* ptr = buf;
542            while( *ptr != '\0' && isspace(*ptr) )
543                ptr++;
544            // skip commented off lines
545            if(*ptr == '#')
546                continue;
547            rowvals.clear();
548            rowtypes.clear();
549
550            char* token = strtok(buf, delimiters);
551            if (!token)
552                break;
553
554            for(;;)
555            {
556                float val=0.f; int tp = 0;
557                decodeElem( token, val, tp, missch, tempNameMap, catCounter );
558                if( tp == VAR_MISSED )
559                    haveMissed = true;
560                rowvals.push_back(val);
561                rowtypes.push_back((uchar)tp);
562                token = strtok(NULL, delimiters);
563                if (!token)
564                    break;
565            }
566
567            if( nvars == 0 )
568            {
569                if( rowvals.empty() )
570                    CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
571                nvars = (int)rowvals.size();
572                if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
573                {
574                    setVarTypes(varTypeSpec, nvars, vtypes);
575                    varTypesSet = true;
576                }
577                else
578                    vtypes = rowtypes;
579
580                ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
581                ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
582                CV_Assert(ridx1 > ridx0);
583                noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
584                ninputvars = nvars - noutputvars;
585            }
586            else
587                CV_Assert( nvars == (int)rowvals.size() );
588
589            // check var types
590            for( i = 0; i < nvars; i++ )
591            {
592                CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
593                           (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
594            }
595
596            if( ridx0 >= 0 )
597            {
598                for( i = ridx1; i < nvars; i++ )
599                    std::swap(rowvals[i], rowvals[i-noutputvars]);
600                for( i = ninputvars; i < nvars; i++ )
601                    allresponses.push_back(rowvals[i]);
602                rowvals.pop_back();
603            }
604            Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
605            tempSamples.push_back(rmat);
606        }
607
608        closeFile();
609
610        int nsamples = tempSamples.rows;
611        if( nsamples == 0 )
612            return false;
613
614        if( haveMissed )
615            compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
616
617        if( ridx0 >= 0 )
618        {
619            for( i = ridx1; i < nvars; i++ )
620                std::swap(vtypes[i], vtypes[i-noutputvars]);
621            if( noutputvars > 1 )
622            {
623                for( i = ninputvars; i < nvars; i++ )
624                    if( vtypes[i] == VAR_CATEGORICAL )
625                        CV_Error(CV_StsBadArg,
626                                 "If responses are vector values, not scalars, they must be marked as ordered responses");
627            }
628        }
629
630        if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
631        {
632            for( i = 0; i < nsamples; i++ )
633                if( allresponses[i] != cvRound(allresponses[i]) )
634                    break;
635            if( i == nsamples )
636                vtypes[ninputvars] = VAR_CATEGORICAL;
637        }
638
639        Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
640        setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
641                noArray(), Mat(vtypes).clone(), tempMissing);
642        bool ok = !samples.empty();
643        if(ok)
644            std::swap(tempNameMap, nameMap);
645        return ok;
646    }
647
648    void decodeElem( const char* token, float& elem, int& type,
649                     char missch, MapType& namemap, int& counter ) const
650    {
651        char* stopstring = NULL;
652        elem = (float)strtod( token, &stopstring );
653        if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
654        {
655            elem = MISSED_VAL;
656            type = VAR_MISSED;
657        }
658        else if( *stopstring != '\0' )
659        {
660            MapType::iterator it = namemap.find(token);
661            if( it == namemap.end() )
662            {
663                elem = (float)counter;
664                namemap[token] = counter++;
665            }
666            else
667                elem = (float)it->second;
668            type = VAR_CATEGORICAL;
669        }
670        else
671            type = VAR_ORDERED;
672    }
673
674    void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
675    {
676        const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
677          "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
678        const char* str = s.c_str();
679        int specCounter = 0;
680
681        vtypes.resize(nvars);
682
683        for( int k = 0; k < 2; k++ )
684        {
685            const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
686            int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
687            if( ptr ) // parse ord/cat str
688            {
689                char* stopstring = NULL;
690
691                if( ptr[3] == '\0' )
692                {
693                    for( int i = 0; i < nvars; i++ )
694                        vtypes[i] = (uchar)tp;
695                    specCounter = nvars;
696                    break;
697                }
698
699                if ( ptr[3] != '[')
700                    CV_Error( CV_StsBadArg, errmsg );
701
702                ptr += 4; // pass "ord["
703                do
704                {
705                    int b1 = (int)strtod( ptr, &stopstring );
706                    if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
707                        CV_Error( CV_StsBadArg, errmsg );
708                    ptr = stopstring + 1;
709                    if( (stopstring[0] == ',') || (stopstring[0] == ']'))
710                    {
711                        CV_Assert( 0 <= b1 && b1 < nvars );
712                        vtypes[b1] = (uchar)tp;
713                        specCounter++;
714                    }
715                    else
716                    {
717                        if( stopstring[0] == '-')
718                        {
719                            int b2 = (int)strtod( ptr, &stopstring);
720                            if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
721                                CV_Error( CV_StsBadArg, errmsg );
722                            ptr = stopstring + 1;
723                            CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
724                            for (int i = b1; i <= b2; i++)
725                                vtypes[i] = (uchar)tp;
726                            specCounter += b2 - b1 + 1;
727                        }
728                        else
729                            CV_Error( CV_StsBadArg, errmsg );
730
731                    }
732                }
733                while(*stopstring != ']');
734
735                if( stopstring[1] != '\0' && stopstring[1] != ',')
736                    CV_Error( CV_StsBadArg, errmsg );
737            }
738        }
739
740        if( specCounter != nvars )
741            CV_Error( CV_StsBadArg, "type of some variables is not specified" );
742    }
743
744    void setTrainTestSplitRatio(double ratio, bool shuffle)
745    {
746        CV_Assert( 0. <= ratio && ratio <= 1. );
747        setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
748    }
749
750    void setTrainTestSplit(int count, bool shuffle)
751    {
752        int i, nsamples = getNSamples();
753        CV_Assert( 0 <= count && count < nsamples );
754
755        trainSampleIdx.release();
756        testSampleIdx.release();
757
758        if( count == 0 )
759            trainSampleIdx = sampleIdx;
760        else if( count == nsamples )
761            testSampleIdx = sampleIdx;
762        else
763        {
764            Mat mask(1, nsamples, CV_8U);
765            uchar* mptr = mask.ptr();
766            for( i = 0; i < nsamples; i++ )
767                mptr[i] = (uchar)(i < count);
768            trainSampleIdx.create(1, count, CV_32S);
769            testSampleIdx.create(1, nsamples - count, CV_32S);
770            int j0 = 0, j1 = 0;
771            const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
772            int* trainptr = trainSampleIdx.ptr<int>();
773            int* testptr = testSampleIdx.ptr<int>();
774            for( i = 0; i < nsamples; i++ )
775            {
776                int idx = sptr ? sptr[i] : i;
777                if( mptr[i] )
778                    trainptr[j0++] = idx;
779                else
780                    testptr[j1++] = idx;
781            }
782            if( shuffle )
783                shuffleTrainTest();
784        }
785    }
786
787    void shuffleTrainTest()
788    {
789        if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
790        {
791            int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
792            int* trainIdx = trainSampleIdx.ptr<int>();
793            int* testIdx = testSampleIdx.ptr<int>();
794            RNG& rng = theRNG();
795
796            for( i = 0; i < nsamples; i++)
797            {
798                int a = rng.uniform(0, nsamples);
799                int b = rng.uniform(0, nsamples);
800                int* ptra = trainIdx;
801                int* ptrb = trainIdx;
802                if( a >= ntrain )
803                {
804                    ptra = testIdx;
805                    a -= ntrain;
806                    CV_Assert( a < ntest );
807                }
808                if( b >= ntrain )
809                {
810                    ptrb = testIdx;
811                    b -= ntrain;
812                    CV_Assert( b < ntest );
813                }
814                std::swap(ptra[a], ptrb[b]);
815            }
816        }
817    }
818
819    Mat getTrainSamples(int _layout,
820                        bool compressSamples,
821                        bool compressVars) const
822    {
823        if( samples.empty() )
824            return samples;
825
826        if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
827            (!compressVars || varIdx.empty()) &&
828            layout == _layout )
829            return samples;
830
831        int drows = getNTrainSamples(), dcols = getNVars();
832        Mat sidx = getTrainSampleIdx(), vidx = getVarIdx();
833        const float* src0 = samples.ptr<float>();
834        const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
835        const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
836        size_t sstep0 = samples.step/samples.elemSize();
837        size_t sstep = layout == ROW_SAMPLE ? sstep0 : 1;
838        size_t vstep = layout == ROW_SAMPLE ? 1 : sstep0;
839
840        if( _layout == COL_SAMPLE )
841        {
842            std::swap(drows, dcols);
843            std::swap(sptr, vptr);
844            std::swap(sstep, vstep);
845        }
846
847        Mat dsamples(drows, dcols, CV_32F);
848
849        for( int i = 0; i < drows; i++ )
850        {
851            const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
852            float* dst = dsamples.ptr<float>(i);
853
854            for( int j = 0; j < dcols; j++ )
855                dst[j] = src[(vptr ? vptr[j] : j)*vstep];
856        }
857
858        return dsamples;
859    }
860
861    void getValues( int vi, InputArray _sidx, float* values ) const
862    {
863        Mat sidx = _sidx.getMat();
864        int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
865        CV_Assert( 0 <= vi && vi < getNAllVars() );
866        CV_Assert( n >= 0 );
867        const int* s = n > 0 ? sidx.ptr<int>() : 0;
868        if( n == 0 )
869            n = nsamples;
870
871        size_t step = samples.step/samples.elemSize();
872        size_t sstep = layout == ROW_SAMPLE ? step : 1;
873        size_t vstep = layout == ROW_SAMPLE ? 1 : step;
874
875        const float* src = samples.ptr<float>() + vi*vstep;
876        float subst = missingSubst.at<float>(vi);
877        for( i = 0; i < n; i++ )
878        {
879            int j = i;
880            if( s )
881            {
882                j = s[i];
883                CV_Assert( 0 <= j && j < nsamples );
884            }
885            values[i] = src[j*sstep];
886            if( values[i] == MISSED_VAL )
887                values[i] = subst;
888        }
889    }
890
891    void getNormCatValues( int vi, InputArray _sidx, int* values ) const
892    {
893        float* fvalues = (float*)values;
894        getValues(vi, _sidx, fvalues);
895        int i, n = (int)_sidx.total();
896        Vec2i ofs = catOfs.at<Vec2i>(vi);
897        int m = ofs[1] - ofs[0];
898
899        CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
900        const int* cmap = &catMap.at<int>(ofs[0]);
901        bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
902
903        if( fastMap )
904        {
905            for( i = 0; i < n; i++ )
906            {
907                int val = cvRound(fvalues[i]);
908                int idx = val - cmap[0];
909                CV_Assert(cmap[idx] == val);
910                values[i] = idx;
911            }
912        }
913        else
914        {
915            for( i = 0; i < n; i++ )
916            {
917                int val = cvRound(fvalues[i]);
918                int a = 0, b = m, c = -1;
919
920                while( a < b )
921                {
922                    c = (a + b) >> 1;
923                    if( val < cmap[c] )
924                        b = c;
925                    else if( val > cmap[c] )
926                        a = c+1;
927                    else
928                        break;
929                }
930
931                CV_DbgAssert( c >= 0 && val == cmap[c] );
932                values[i] = c;
933            }
934        }
935    }
936
937    void getSample(InputArray _vidx, int sidx, float* buf) const
938    {
939        CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
940        Mat vidx = _vidx.getMat();
941        int i, n = vidx.checkVector(1, CV_32S), nvars = getNAllVars();
942        CV_Assert( n >= 0 );
943        const int* vptr = n > 0 ? vidx.ptr<int>() : 0;
944        if( n == 0 )
945            n = nvars;
946
947        size_t step = samples.step/samples.elemSize();
948        size_t sstep = layout == ROW_SAMPLE ? step : 1;
949        size_t vstep = layout == ROW_SAMPLE ? 1 : step;
950
951        const float* src = samples.ptr<float>() + sidx*sstep;
952        for( i = 0; i < n; i++ )
953        {
954            int j = i;
955            if( vptr )
956            {
957                j = vptr[i];
958                CV_Assert( 0 <= j && j < nvars );
959            }
960            buf[i] = src[j*vstep];
961        }
962    }
963
964    FILE* file;
965    int layout;
966    Mat samples, missing, varType, varIdx, responses, missingSubst;
967    Mat sampleIdx, trainSampleIdx, testSampleIdx;
968    Mat sampleWeights, catMap, catOfs;
969    Mat normCatResponses, classLabels, classCounters;
970    MapType nameMap;
971};
972
973Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
974                                      int headerLines,
975                                      int responseStartIdx,
976                                      int responseEndIdx,
977                                      const String& varTypeSpec,
978                                      char delimiter, char missch)
979{
980    Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
981    if(!td->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch))
982        td.release();
983    return td;
984}
985
986Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
987                                 InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
988                                 InputArray varType)
989{
990    Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
991    td->setData(samples, layout, responses, varIdx, sampleIdx, sampleWeights, varType, noArray());
992    return td;
993}
994
995}}
996
997/* End of file. */
998