1#include "opencv2/core.hpp" 2#include "opencv2/imgproc.hpp" 3#include "opencv2/ml.hpp" 4#include "opencv2/highgui.hpp" 5#ifdef HAVE_OPENCV_OCL 6#define _OCL_KNN_ 1 // select whether using ocl::KNN method or not, default is using 7#define _OCL_SVM_ 1 // select whether using ocl::svm method or not, default is using 8#include "opencv2/ocl/ocl.hpp" 9#endif 10 11#include <stdio.h> 12 13using namespace std; 14using namespace cv; 15using namespace cv::ml; 16 17const Scalar WHITE_COLOR = Scalar(255,255,255); 18const string winName = "points"; 19const int testStep = 5; 20 21Mat img, imgDst; 22RNG rng; 23 24vector<Point> trainedPoints; 25vector<int> trainedPointsMarkers; 26const int MAX_CLASSES = 2; 27vector<Vec3b> classColors(MAX_CLASSES); 28int currentClass = 0; 29vector<int> classCounters(MAX_CLASSES); 30 31#define _NBC_ 1 // normal Bayessian classifier 32#define _KNN_ 1 // k nearest neighbors classifier 33#define _SVM_ 1 // support vectors machine 34#define _DT_ 1 // decision tree 35#define _BT_ 1 // ADA Boost 36#define _GBT_ 0 // gradient boosted trees 37#define _RF_ 1 // random forest 38#define _ANN_ 1 // artificial neural networks 39#define _EM_ 1 // expectation-maximization 40 41static void on_mouse( int event, int x, int y, int /*flags*/, void* ) 42{ 43 if( img.empty() ) 44 return; 45 46 int updateFlag = 0; 47 48 if( event == EVENT_LBUTTONUP ) 49 { 50 trainedPoints.push_back( Point(x,y) ); 51 trainedPointsMarkers.push_back( currentClass ); 52 classCounters[currentClass]++; 53 updateFlag = true; 54 } 55 56 //draw 57 if( updateFlag ) 58 { 59 img = Scalar::all(0); 60 61 // draw points 62 for( size_t i = 0; i < trainedPoints.size(); i++ ) 63 { 64 Vec3b c = classColors[trainedPointsMarkers[i]]; 65 circle( img, trainedPoints[i], 5, Scalar(c), -1 ); 66 } 67 68 imshow( winName, img ); 69 } 70} 71 72static Mat prepare_train_samples(const vector<Point>& pts) 73{ 74 Mat samples; 75 Mat(pts).reshape(1, (int)pts.size()).convertTo(samples, CV_32F); 76 return samples; 77} 78 79static Ptr<TrainData> prepare_train_data() 80{ 81 Mat samples = prepare_train_samples(trainedPoints); 82 return TrainData::create(samples, ROW_SAMPLE, Mat(trainedPointsMarkers)); 83} 84 85static void predict_and_paint(const Ptr<StatModel>& model, Mat& dst) 86{ 87 Mat testSample( 1, 2, CV_32FC1 ); 88 for( int y = 0; y < img.rows; y += testStep ) 89 { 90 for( int x = 0; x < img.cols; x += testStep ) 91 { 92 testSample.at<float>(0) = (float)x; 93 testSample.at<float>(1) = (float)y; 94 95 int response = (int)model->predict( testSample ); 96 dst.at<Vec3b>(y, x) = classColors[response]; 97 } 98 } 99} 100 101#if _NBC_ 102static void find_decision_boundary_NBC() 103{ 104 // learn classifier 105 Ptr<NormalBayesClassifier> normalBayesClassifier = StatModel::train<NormalBayesClassifier>(prepare_train_data()); 106 107 predict_and_paint(normalBayesClassifier, imgDst); 108} 109#endif 110 111 112#if _KNN_ 113static void find_decision_boundary_KNN( int K ) 114{ 115 116 Ptr<KNearest> knn = KNearest::create(); 117 knn->setDefaultK(K); 118 knn->setIsClassifier(true); 119 knn->train(prepare_train_data()); 120 predict_and_paint(knn, imgDst); 121} 122#endif 123 124#if _SVM_ 125static void find_decision_boundary_SVM( double C ) 126{ 127 Ptr<SVM> svm = SVM::create(); 128 svm->setType(SVM::C_SVC); 129 svm->setKernel(SVM::POLY); //SVM::LINEAR; 130 svm->setDegree(0.5); 131 svm->setGamma(1); 132 svm->setCoef0(1); 133 svm->setNu(0.5); 134 svm->setP(0); 135 svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 1000, 0.01)); 136 svm->setC(C); 137 svm->train(prepare_train_data()); 138 predict_and_paint(svm, imgDst); 139 140 Mat sv = svm->getSupportVectors(); 141 for( int i = 0; i < sv.rows; i++ ) 142 { 143 const float* supportVector = sv.ptr<float>(i); 144 circle( imgDst, Point(saturate_cast<int>(supportVector[0]),saturate_cast<int>(supportVector[1])), 5, Scalar(255,255,255), -1 ); 145 } 146} 147#endif 148 149#if _DT_ 150static void find_decision_boundary_DT() 151{ 152 Ptr<DTrees> dtree = DTrees::create(); 153 dtree->setMaxDepth(8); 154 dtree->setMinSampleCount(2); 155 dtree->setUseSurrogates(false); 156 dtree->setCVFolds(0); // the number of cross-validation folds 157 dtree->setUse1SERule(false); 158 dtree->setTruncatePrunedTree(false); 159 dtree->train(prepare_train_data()); 160 predict_and_paint(dtree, imgDst); 161} 162#endif 163 164#if _BT_ 165static void find_decision_boundary_BT() 166{ 167 Ptr<Boost> boost = Boost::create(); 168 boost->setBoostType(Boost::DISCRETE); 169 boost->setWeakCount(100); 170 boost->setWeightTrimRate(0.95); 171 boost->setMaxDepth(2); 172 boost->setUseSurrogates(false); 173 boost->setPriors(Mat()); 174 boost->train(prepare_train_data()); 175 predict_and_paint(boost, imgDst); 176} 177 178#endif 179 180#if _GBT_ 181static void find_decision_boundary_GBT() 182{ 183 GBTrees::Params params( GBTrees::DEVIANCE_LOSS, // loss_function_type 184 100, // weak_count 185 0.1f, // shrinkage 186 1.0f, // subsample_portion 187 2, // max_depth 188 false // use_surrogates ) 189 ); 190 191 Ptr<GBTrees> gbtrees = StatModel::train<GBTrees>(prepare_train_data(), params); 192 predict_and_paint(gbtrees, imgDst); 193} 194#endif 195 196#if _RF_ 197static void find_decision_boundary_RF() 198{ 199 Ptr<RTrees> rtrees = RTrees::create(); 200 rtrees->setMaxDepth(4); 201 rtrees->setMinSampleCount(2); 202 rtrees->setRegressionAccuracy(0.f); 203 rtrees->setUseSurrogates(false); 204 rtrees->setMaxCategories(16); 205 rtrees->setPriors(Mat()); 206 rtrees->setCalculateVarImportance(false); 207 rtrees->setActiveVarCount(1); 208 rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 5, 0)); 209 rtrees->train(prepare_train_data()); 210 predict_and_paint(rtrees, imgDst); 211} 212 213#endif 214 215#if _ANN_ 216static void find_decision_boundary_ANN( const Mat& layer_sizes ) 217{ 218 Mat trainClasses = Mat::zeros( (int)trainedPoints.size(), (int)classColors.size(), CV_32FC1 ); 219 for( int i = 0; i < trainClasses.rows; i++ ) 220 { 221 trainClasses.at<float>(i, trainedPointsMarkers[i]) = 1.f; 222 } 223 224 Mat samples = prepare_train_samples(trainedPoints); 225 Ptr<TrainData> tdata = TrainData::create(samples, ROW_SAMPLE, trainClasses); 226 227 Ptr<ANN_MLP> ann = ANN_MLP::create(); 228 ann->setLayerSizes(layer_sizes); 229 ann->setActivationFunction(ANN_MLP::SIGMOID_SYM, 1, 1); 230 ann->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON)); 231 ann->setTrainMethod(ANN_MLP::BACKPROP, 0.001); 232 ann->train(tdata); 233 predict_and_paint(ann, imgDst); 234} 235#endif 236 237#if _EM_ 238static void find_decision_boundary_EM() 239{ 240 img.copyTo( imgDst ); 241 242 Mat samples = prepare_train_samples(trainedPoints); 243 244 int i, j, nmodels = (int)classColors.size(); 245 vector<Ptr<EM> > em_models(nmodels); 246 Mat modelSamples; 247 248 for( i = 0; i < nmodels; i++ ) 249 { 250 const int componentCount = 3; 251 252 modelSamples.release(); 253 for( j = 0; j < samples.rows; j++ ) 254 { 255 if( trainedPointsMarkers[j] == i ) 256 modelSamples.push_back(samples.row(j)); 257 } 258 259 // learn models 260 if( !modelSamples.empty() ) 261 { 262 Ptr<EM> em = EM::create(); 263 em->setClustersNumber(componentCount); 264 em->setCovarianceMatrixType(EM::COV_MAT_DIAGONAL); 265 em->trainEM(modelSamples, noArray(), noArray(), noArray()); 266 em_models[i] = em; 267 } 268 } 269 270 // classify coordinate plane points using the bayes classifier, i.e. 271 // y(x) = arg max_i=1_modelsCount likelihoods_i(x) 272 Mat testSample(1, 2, CV_32FC1 ); 273 Mat logLikelihoods(1, nmodels, CV_64FC1, Scalar(-DBL_MAX)); 274 275 for( int y = 0; y < img.rows; y += testStep ) 276 { 277 for( int x = 0; x < img.cols; x += testStep ) 278 { 279 testSample.at<float>(0) = (float)x; 280 testSample.at<float>(1) = (float)y; 281 282 for( i = 0; i < nmodels; i++ ) 283 { 284 if( !em_models[i].empty() ) 285 logLikelihoods.at<double>(i) = em_models[i]->predict2(testSample, noArray())[0]; 286 } 287 Point maxLoc; 288 minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc); 289 imgDst.at<Vec3b>(y, x) = classColors[maxLoc.x]; 290 } 291 } 292} 293#endif 294 295int main() 296{ 297 cout << "Use:" << endl 298 << " key '0' .. '1' - switch to class #n" << endl 299 << " left mouse button - to add new point;" << endl 300 << " key 'r' - to run the ML model;" << endl 301 << " key 'i' - to init (clear) the data." << endl << endl; 302 303 cv::namedWindow( "points", 1 ); 304 img.create( 480, 640, CV_8UC3 ); 305 imgDst.create( 480, 640, CV_8UC3 ); 306 307 imshow( "points", img ); 308 setMouseCallback( "points", on_mouse ); 309 310 classColors[0] = Vec3b(0, 255, 0); 311 classColors[1] = Vec3b(0, 0, 255); 312 313 for(;;) 314 { 315 uchar key = (uchar)waitKey(); 316 317 if( key == 27 ) break; 318 319 if( key == 'i' ) // init 320 { 321 img = Scalar::all(0); 322 323 trainedPoints.clear(); 324 trainedPointsMarkers.clear(); 325 classCounters.assign(MAX_CLASSES, 0); 326 327 imshow( winName, img ); 328 } 329 330 if( key == '0' || key == '1' ) 331 { 332 currentClass = key - '0'; 333 } 334 335 if( key == 'r' ) // run 336 { 337 double minVal = 0; 338 minMaxLoc(classCounters, &minVal, 0, 0, 0); 339 if( minVal == 0 ) 340 { 341 printf("each class should have at least 1 point\n"); 342 continue; 343 } 344 img.copyTo( imgDst ); 345#if _NBC_ 346 find_decision_boundary_NBC(); 347 imshow( "NormalBayesClassifier", imgDst ); 348#endif 349#if _KNN_ 350 find_decision_boundary_KNN( 3 ); 351 imshow( "kNN", imgDst ); 352 353 find_decision_boundary_KNN( 15 ); 354 imshow( "kNN2", imgDst ); 355#endif 356 357#if _SVM_ 358 //(1)-(2)separable and not sets 359 360 find_decision_boundary_SVM( 1 ); 361 imshow( "classificationSVM1", imgDst ); 362 363 find_decision_boundary_SVM( 10 ); 364 imshow( "classificationSVM2", imgDst ); 365#endif 366 367#if _DT_ 368 find_decision_boundary_DT(); 369 imshow( "DT", imgDst ); 370#endif 371 372#if _BT_ 373 find_decision_boundary_BT(); 374 imshow( "BT", imgDst); 375#endif 376 377#if _GBT_ 378 find_decision_boundary_GBT(); 379 imshow( "GBT", imgDst); 380#endif 381 382#if _RF_ 383 find_decision_boundary_RF(); 384 imshow( "RF", imgDst); 385#endif 386 387#if _ANN_ 388 Mat layer_sizes1( 1, 3, CV_32SC1 ); 389 layer_sizes1.at<int>(0) = 2; 390 layer_sizes1.at<int>(1) = 5; 391 layer_sizes1.at<int>(2) = (int)classColors.size(); 392 find_decision_boundary_ANN( layer_sizes1 ); 393 imshow( "ANN", imgDst ); 394#endif 395 396#if _EM_ 397 find_decision_boundary_EM(); 398 imshow( "EM", imgDst ); 399#endif 400 } 401 } 402 403 return 0; 404} 405