#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/ml.hpp"
#include "opencv2/highgui.hpp"
#ifdef HAVE_OPENCV_OCL
#define _OCL_KNN_ 1 // set to 0 to disable the ocl::KNN path (enabled by default)
#define _OCL_SVM_ 1 // set to 0 to disable the ocl::SVM path (enabled by default)
#include "opencv2/ocl/ocl.hpp"
#endif

#include <iostream>
#include <stdio.h>

using namespace std;
using namespace cv;
using namespace cv::ml;

const Scalar WHITE_COLOR = Scalar(255,255,255);
const string winName = "points";
const int testStep = 5;

Mat img, imgDst;
RNG rng;

vector<Point>  trainedPoints;
vector<int>    trainedPointsMarkers;
const int MAX_CLASSES = 2;
vector<Vec3b>  classColors(MAX_CLASSES);
int currentClass = 0;
vector<int> classCounters(MAX_CLASSES);

#define _NBC_ 1 // normal Bayes classifier
#define _KNN_ 1 // k nearest neighbors classifier
#define _SVM_ 1 // support vector machine
#define _DT_  1 // decision tree
#define _BT_  1 // AdaBoost
#define _GBT_ 0 // gradient boosted trees
#define _RF_  1 // random forest
#define _ANN_ 1 // artificial neural networks
#define _EM_  1 // expectation-maximization

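// Mouse callback: a left-button click adds the clicked point to the training
// set, labels it with the currently selected class and redraws all points.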
static void on_mouse( int event, int x, int y, int /*flags*/, void* )
{
    if( img.empty() )
        return;

    bool updateFlag = false;

    if( event == EVENT_LBUTTONUP )
    {
        trainedPoints.push_back( Point(x,y) );
        trainedPointsMarkers.push_back( currentClass );
        classCounters[currentClass]++;
        updateFlag = true;
    }

    // redraw
    if( updateFlag )
    {
        img = Scalar::all(0);

        // draw points
        for( size_t i = 0; i < trainedPoints.size(); i++ )
        {
            Vec3b c = classColors[trainedPointsMarkers[i]];
            circle( img, trainedPoints[i], 5, Scalar(c), -1 );
        }

        imshow( winName, img );
    }
}

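// Convert the clicked points into a CV_32F matrix with one 2-D sample per row,
// which is the sample layout expected by the ml module.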
static Mat prepare_train_samples(const vector<Point>& pts)
{
    Mat samples;
    Mat(pts).reshape(1, (int)pts.size()).convertTo(samples, CV_32F);
    return samples;
}

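// Bundle the samples and their integer class labels into a TrainData object
// (ROW_SAMPLE: each row of the sample matrix is one training sample).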
static Ptr<TrainData> prepare_train_data()
{
    Mat samples = prepare_train_samples(trainedPoints);
    return TrainData::create(samples, ROW_SAMPLE, Mat(trainedPointsMarkers));
}

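// Visualize the decision regions: classify every testStep-th pixel of the
// image plane with the trained model and paint it with its predicted class color.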
static void predict_and_paint(const Ptr<StatModel>& model, Mat& dst)
{
    Mat testSample( 1, 2, CV_32FC1 );
    for( int y = 0; y < img.rows; y += testStep )
    {
        for( int x = 0; x < img.cols; x += testStep )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;

            int response = (int)model->predict( testSample );
            dst.at<Vec3b>(y, x) = classColors[response];
        }
    }
}

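// Normal Bayes classifier: assumes the feature vectors of each class are
// normally distributed and picks the class with the highest posterior probability.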
#if _NBC_
static void find_decision_boundary_NBC()
{
    // learn classifier
    Ptr<NormalBayesClassifier> normalBayesClassifier = StatModel::train<NormalBayesClassifier>(prepare_train_data());

    predict_and_paint(normalBayesClassifier, imgDst);
}
#endif

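// k-nearest neighbors: each grid point is assigned the majority label
// among its K closest training points.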
#if _KNN_
static void find_decision_boundary_KNN( int K )
{
    Ptr<KNearest> knn = KNearest::create();
    knn->setDefaultK(K);
    knn->setIsClassifier(true);
    knn->train(prepare_train_data());
    predict_and_paint(knn, imgDst);
}
#endif

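// Support vector machine with a polynomial kernel. C controls the trade-off
// between a smooth decision boundary and classifying all training points
// correctly; the support vectors are drawn as filled white circles.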
#if _SVM_
static void find_decision_boundary_SVM( double C )
{
    Ptr<SVM> svm = SVM::create();
    svm->setType(SVM::C_SVC);
    svm->setKernel(SVM::POLY); // or SVM::LINEAR
    svm->setDegree(0.5);
    svm->setGamma(1);
    svm->setCoef0(1);
    svm->setNu(0.5);
    svm->setP(0);
    svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 1000, 0.01));
    svm->setC(C);
    svm->train(prepare_train_data());
    predict_and_paint(svm, imgDst);

    Mat sv = svm->getSupportVectors();
    for( int i = 0; i < sv.rows; i++ )
    {
        const float* supportVector = sv.ptr<float>(i);
        circle( imgDst, Point(saturate_cast<int>(supportVector[0]), saturate_cast<int>(supportVector[1])), 5, Scalar(255,255,255), -1 );
    }
}
#endif

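// Single decision tree (CART). CVFolds is set to 0, so no built-in
// cross-validation pruning is performed.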
#if _DT_
static void find_decision_boundary_DT()
{
    Ptr<DTrees> dtree = DTrees::create();
    dtree->setMaxDepth(8);
    dtree->setMinSampleCount(2);
    dtree->setUseSurrogates(false);
    dtree->setCVFolds(0); // the number of cross-validation folds
    dtree->setUse1SERule(false);
    dtree->setTruncatePrunedTree(false);
    dtree->train(prepare_train_data());
    predict_and_paint(dtree, imgDst);
}
#endif

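// Boosted classifier (discrete AdaBoost): an ensemble of 100 depth-2 trees.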
#if _BT_
static void find_decision_boundary_BT()
{
    Ptr<Boost> boost = Boost::create();
    boost->setBoostType(Boost::DISCRETE);
    boost->setWeakCount(100);
    boost->setWeightTrimRate(0.95);
    boost->setMaxDepth(2);
    boost->setUseSurrogates(false);
    boost->setPriors(Mat());
    boost->train(prepare_train_data());
    predict_and_paint(boost, imgDst);
}
#endif

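// Gradient boosted trees. Disabled by default (_GBT_ is 0) because the code
// below still uses the legacy GBTrees::Params interface, which is not exposed
// by the current ml module API.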
#if _GBT_
static void find_decision_boundary_GBT()
{
    GBTrees::Params params( GBTrees::DEVIANCE_LOSS, // loss_function_type
                         100,   // weak_count
                         0.1f,  // shrinkage
                         1.0f,  // subsample_portion
                         2,     // max_depth
                         false  // use_surrogates
                         );

    Ptr<GBTrees> gbtrees = StatModel::train<GBTrees>(prepare_train_data(), params);
    predict_and_paint(gbtrees, imgDst);
}
#endif

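// Random forest: an ensemble of shallow trees trained on random subsets of the
// data; the prediction is the majority vote over all trees.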
#if _RF_
static void find_decision_boundary_RF()
{
    Ptr<RTrees> rtrees = RTrees::create();
    rtrees->setMaxDepth(4);
    rtrees->setMinSampleCount(2);
    rtrees->setRegressionAccuracy(0.f);
    rtrees->setUseSurrogates(false);
    rtrees->setMaxCategories(16);
    rtrees->setPriors(Mat());
    rtrees->setCalculateVarImportance(false);
    rtrees->setActiveVarCount(1);
    rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 5, 0));
    rtrees->train(prepare_train_data());
    predict_and_paint(rtrees, imgDst);
}
#endif

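// Multi-layer perceptron. The integer class labels are first unrolled into
// one-hot rows, because ANN_MLP is trained with one output neuron per class.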
#if _ANN_
static void find_decision_boundary_ANN( const Mat& layer_sizes )
{
    Mat trainClasses = Mat::zeros( (int)trainedPoints.size(), (int)classColors.size(), CV_32FC1 );
    for( int i = 0; i < trainClasses.rows; i++ )
    {
        trainClasses.at<float>(i, trainedPointsMarkers[i]) = 1.f;
    }

    Mat samples = prepare_train_samples(trainedPoints);
    Ptr<TrainData> tdata = TrainData::create(samples, ROW_SAMPLE, trainClasses);

    Ptr<ANN_MLP> ann = ANN_MLP::create();
    ann->setLayerSizes(layer_sizes);
    ann->setActivationFunction(ANN_MLP::SIGMOID_SYM, 1, 1);
    ann->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON));
    ann->setTrainMethod(ANN_MLP::BACKPROP, 0.001);
    ann->train(tdata);
    predict_and_paint(ann, imgDst);
}
#endif

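// Expectation-maximization: one Gaussian mixture model (3 components, diagonal
// covariances) is fitted per class, and each grid point is assigned to the
// class whose model yields the highest log-likelihood.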
#if _EM_
static void find_decision_boundary_EM()
{
    img.copyTo( imgDst );

    Mat samples = prepare_train_samples(trainedPoints);

    int i, j, nmodels = (int)classColors.size();
    vector<Ptr<EM> > em_models(nmodels);
    Mat modelSamples;

    for( i = 0; i < nmodels; i++ )
    {
        const int componentCount = 3;

        modelSamples.release();
        for( j = 0; j < samples.rows; j++ )
        {
            if( trainedPointsMarkers[j] == i )
                modelSamples.push_back(samples.row(j));
        }

        // learn models
        if( !modelSamples.empty() )
        {
            Ptr<EM> em = EM::create();
            em->setClustersNumber(componentCount);
            em->setCovarianceMatrixType(EM::COV_MAT_DIAGONAL);
            em->trainEM(modelSamples, noArray(), noArray(), noArray());
            em_models[i] = em;
        }
    }

    // classify coordinate plane points using the Bayes classifier, i.e.
    // y(x) = arg max_{i=1..nmodels} likelihoods_i(x)
    Mat testSample( 1, 2, CV_32FC1 );
    Mat logLikelihoods( 1, nmodels, CV_64FC1, Scalar(-DBL_MAX) );

    for( int y = 0; y < img.rows; y += testStep )
    {
        for( int x = 0; x < img.cols; x += testStep )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;

            for( i = 0; i < nmodels; i++ )
            {
                if( !em_models[i].empty() )
                    logLikelihoods.at<double>(i) = em_models[i]->predict2(testSample, noArray())[0];
            }
            Point maxLoc;
            minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
            imgDst.at<Vec3b>(y, x) = classColors[maxLoc.x];
        }
    }
}
#endif

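// Interactive demo: collect labeled points with the mouse and, on 'r', train
// each enabled model and show its decision regions in a separate window.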
int main()
{
    cout << "Use:" << endl
         << "  keys '0'..'1' - switch to class #0 or #1" << endl
         << "  left mouse button - add a new point of the current class" << endl
         << "  key 'r' - run the ML models" << endl
         << "  key 'i' - init (clear) the data" << endl
         << "  key ESC - exit" << endl << endl;

    namedWindow( winName, WINDOW_AUTOSIZE );
    img.create( 480, 640, CV_8UC3 );
    imgDst.create( 480, 640, CV_8UC3 );

    imshow( winName, img );
    setMouseCallback( winName, on_mouse );

    classColors[0] = Vec3b(0, 255, 0);
    classColors[1] = Vec3b(0, 0, 255);

    for(;;)
    {
        uchar key = (uchar)waitKey();

        if( key == 27 ) break; // ESC

        if( key == 'i' ) // init
        {
            img = Scalar::all(0);

            trainedPoints.clear();
            trainedPointsMarkers.clear();
            classCounters.assign(MAX_CLASSES, 0);

            imshow( winName, img );
        }

        if( key == '0' || key == '1' )
        {
            currentClass = key - '0';
        }

        if( key == 'r' ) // run
        {
            double minVal = 0;
            minMaxLoc(classCounters, &minVal, 0, 0, 0);
            if( minVal == 0 )
            {
                printf("each class should have at least 1 point\n");
                continue;
            }
            img.copyTo( imgDst );
#if _NBC_
            find_decision_boundary_NBC();
            imshow( "NormalBayesClassifier", imgDst );
#endif
#if _KNN_
            find_decision_boundary_KNN( 3 );
            imshow( "kNN", imgDst );

            find_decision_boundary_KNN( 15 );
            imshow( "kNN2", imgDst );
#endif

#if _SVM_
            // (1)-(2): separable and non-separable sets
            find_decision_boundary_SVM( 1 );
            imshow( "classificationSVM1", imgDst );

            find_decision_boundary_SVM( 10 );
            imshow( "classificationSVM2", imgDst );
#endif

#if _DT_
            find_decision_boundary_DT();
            imshow( "DT", imgDst );
#endif

#if _BT_
            find_decision_boundary_BT();
            imshow( "BT", imgDst );
#endif

#if _GBT_
            find_decision_boundary_GBT();
            imshow( "GBT", imgDst );
#endif

#if _RF_
            find_decision_boundary_RF();
            imshow( "RF", imgDst );
#endif

#if _ANN_
            Mat layer_sizes1( 1, 3, CV_32SC1 );
            layer_sizes1.at<int>(0) = 2;
            layer_sizes1.at<int>(1) = 5;
            layer_sizes1.at<int>(2) = (int)classColors.size();
            find_decision_boundary_ANN( layer_sizes1 );
            imshow( "ANN", imgDst );
#endif

#if _EM_
            find_decision_boundary_EM();
            imshow( "EM", imgDst );
#endif
        }
    }

    return 0;
}