1/*M/////////////////////////////////////////////////////////////////////////////////////// 2// 3// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 4// 5// By downloading, copying, installing or using the software you agree to this license. 6// If you do not agree to this license, do not download, install, 7// copy or use the software. 8// 9// 10// Intel License Agreement 11// For Open Source Computer Vision Library 12// 13// Copyright (C) 2000, Intel Corporation, all rights reserved. 14// Third party copyrights are property of their respective owners. 15// 16// Redistribution and use in source and binary forms, with or without modification, 17// are permitted provided that the following conditions are met: 18// 19// * Redistribution's of source code must retain the above copyright notice, 20// this list of conditions and the following disclaimer. 21// 22// * Redistribution's in binary form must reproduce the above copyright notice, 23// this list of conditions and the following disclaimer in the documentation 24// and/or other materials provided with the distribution. 25// 26// * The name of Intel Corporation may not be used to endorse or promote products 27// derived from this software without specific prior written permission. 28// 29// This software is provided by the copyright holders and contributors "as is" and 30// any express or implied warranties, including, but not limited to, the implied 31// warranties of merchantability and fitness for a particular purpose are disclaimed. 32// In no event shall the Intel Corporation or contributors be liable for any direct, 33// indirect, incidental, special, exemplary, or consequential damages 34// (including, but not limited to, procurement of substitute goods or services; 35// loss of use, data, or profits; or business interruption) however caused 36// and on any theory of liability, whether in contract, strict liability, 37// or tort (including negligence or otherwise) arising in any way out of 38// the use of this software, even if advised of the possibility of such damage. 39// 40//M*/ 41 42#include "test_precomp.hpp" 43 44using namespace std; 45using namespace cv; 46using cv::ml::TrainData; 47using cv::ml::EM; 48using cv::ml::KNearest; 49 50static 51void defaultDistribs( Mat& means, vector<Mat>& covs, int type=CV_32FC1 ) 52{ 53 float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f}; 54 float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f}; 55 float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f}; 56 means.create(3, 2, type); 57 Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 ); 58 Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 ); 59 Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 ); 60 means.resize(3), covs.resize(3); 61 62 Mat mr0 = means.row(0); 63 m0.convertTo(mr0, type); 64 c0.convertTo(covs[0], type); 65 66 Mat mr1 = means.row(1); 67 m1.convertTo(mr1, type); 68 c1.convertTo(covs[1], type); 69 70 Mat mr2 = means.row(2); 71 m2.convertTo(mr2, type); 72 c2.convertTo(covs[2], type); 73} 74 75// generate points sets by normal distributions 76static 77void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& _means, const vector<Mat>& covs, int dataType, int labelType ) 78{ 79 vector<int>::const_iterator sit = sizes.begin(); 80 int total = 0; 81 for( ; sit != sizes.end(); ++sit ) 82 total += *sit; 83 CV_Assert( _means.rows == (int)sizes.size() && covs.size() == sizes.size() ); 84 CV_Assert( !data.empty() && data.rows == total ); 85 CV_Assert( data.type() == dataType ); 86 87 labels.create( data.rows, 1, labelType ); 88 89 randn( data, Scalar::all(-1.0), Scalar::all(1.0) ); 90 vector<Mat> means(sizes.size()); 91 for(int i = 0; i < _means.rows; i++) 92 means[i] = _means.row(i); 93 vector<Mat>::const_iterator mit = means.begin(), cit = covs.begin(); 94 int bi, ei = 0; 95 sit = sizes.begin(); 96 for( int p = 0, l = 0; sit != sizes.end(); ++sit, ++mit, ++cit, l++ ) 97 { 98 bi = ei; 99 ei = bi + *sit; 100 assert( mit->rows == 1 && mit->cols == data.cols ); 101 assert( cit->rows == data.cols && cit->cols == data.cols ); 102 for( int i = bi; i < ei; i++, p++ ) 103 { 104 Mat r = data.row(i); 105 r = r * (*cit) + *mit; 106 if( labelType == CV_32FC1 ) 107 labels.at<float>(p, 0) = (float)l; 108 else if( labelType == CV_32SC1 ) 109 labels.at<int>(p, 0) = l; 110 else 111 { 112 CV_DbgAssert(0); 113 } 114 } 115 } 116} 117 118static 119int maxIdx( const vector<int>& count ) 120{ 121 int idx = -1; 122 int maxVal = -1; 123 vector<int>::const_iterator it = count.begin(); 124 for( int i = 0; it != count.end(); ++it, i++ ) 125 { 126 if( *it > maxVal) 127 { 128 maxVal = *it; 129 idx = i; 130 } 131 } 132 assert( idx >= 0); 133 return idx; 134} 135 136static 137bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap, bool checkClusterUniq=true ) 138{ 139 size_t total = 0, nclusters = sizes.size(); 140 for(size_t i = 0; i < sizes.size(); i++) 141 total += sizes[i]; 142 143 assert( !labels.empty() ); 144 assert( labels.total() == total && (labels.cols == 1 || labels.rows == 1)); 145 assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 ); 146 147 bool isFlt = labels.type() == CV_32FC1; 148 149 labelsMap.resize(nclusters); 150 151 vector<bool> buzy(nclusters, false); 152 int startIndex = 0; 153 for( size_t clusterIndex = 0; clusterIndex < sizes.size(); clusterIndex++ ) 154 { 155 vector<int> count( nclusters, 0 ); 156 for( int i = startIndex; i < startIndex + sizes[clusterIndex]; i++) 157 { 158 int lbl = isFlt ? (int)labels.at<float>(i) : labels.at<int>(i); 159 CV_Assert(lbl < (int)nclusters); 160 count[lbl]++; 161 CV_Assert(count[lbl] < (int)total); 162 } 163 startIndex += sizes[clusterIndex]; 164 165 int cls = maxIdx( count ); 166 CV_Assert( !checkClusterUniq || !buzy[cls] ); 167 168 labelsMap[clusterIndex] = cls; 169 170 buzy[cls] = true; 171 } 172 173 if(checkClusterUniq) 174 { 175 for(size_t i = 0; i < buzy.size(); i++) 176 if(!buzy[i]) 177 return false; 178 } 179 180 return true; 181} 182 183static 184bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true, bool checkClusterUniq=true ) 185{ 186 err = 0; 187 CV_Assert( !labels.empty() && !origLabels.empty() ); 188 CV_Assert( labels.rows == 1 || labels.cols == 1 ); 189 CV_Assert( origLabels.rows == 1 || origLabels.cols == 1 ); 190 CV_Assert( labels.total() == origLabels.total() ); 191 CV_Assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 ); 192 CV_Assert( origLabels.type() == labels.type() ); 193 194 vector<int> labelsMap; 195 bool isFlt = labels.type() == CV_32FC1; 196 if( !labelsEquivalent ) 197 { 198 if( !getLabelsMap( labels, sizes, labelsMap, checkClusterUniq ) ) 199 return false; 200 201 for( int i = 0; i < labels.rows; i++ ) 202 if( isFlt ) 203 err += labels.at<float>(i) != labelsMap[(int)origLabels.at<float>(i)] ? 1.f : 0.f; 204 else 205 err += labels.at<int>(i) != labelsMap[origLabels.at<int>(i)] ? 1.f : 0.f; 206 } 207 else 208 { 209 for( int i = 0; i < labels.rows; i++ ) 210 if( isFlt ) 211 err += labels.at<float>(i) != origLabels.at<float>(i) ? 1.f : 0.f; 212 else 213 err += labels.at<int>(i) != origLabels.at<int>(i) ? 1.f : 0.f; 214 } 215 err /= (float)labels.rows; 216 return true; 217} 218 219//-------------------------------------------------------------------------------------------- 220class CV_KMeansTest : public cvtest::BaseTest { 221public: 222 CV_KMeansTest() {} 223protected: 224 virtual void run( int start_from ); 225}; 226 227void CV_KMeansTest::run( int /*start_from*/ ) 228{ 229 const int iters = 100; 230 int sizesArr[] = { 5000, 7000, 8000 }; 231 int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2]; 232 233 Mat data( pointsCount, 2, CV_32FC1 ), labels; 234 vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) ); 235 Mat means; 236 vector<Mat> covs; 237 defaultDistribs( means, covs ); 238 generateData( data, labels, sizes, means, covs, CV_32FC1, CV_32SC1 ); 239 240 int code = cvtest::TS::OK; 241 float err; 242 Mat bestLabels; 243 // 1. flag==KMEANS_PP_CENTERS 244 kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_PP_CENTERS, noArray() ); 245 if( !calcErr( bestLabels, labels, sizes, err , false ) ) 246 { 247 ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_PP_CENTERS.\n" ); 248 code = cvtest::TS::FAIL_INVALID_OUTPUT; 249 } 250 else if( err > 0.01f ) 251 { 252 ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err ); 253 code = cvtest::TS::FAIL_BAD_ACCURACY; 254 } 255 256 // 2. flag==KMEANS_RANDOM_CENTERS 257 kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_RANDOM_CENTERS, noArray() ); 258 if( !calcErr( bestLabels, labels, sizes, err, false ) ) 259 { 260 ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_RANDOM_CENTERS.\n" ); 261 code = cvtest::TS::FAIL_INVALID_OUTPUT; 262 } 263 else if( err > 0.01f ) 264 { 265 ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_RANDOM_CENTERS.\n", err ); 266 code = cvtest::TS::FAIL_BAD_ACCURACY; 267 } 268 269 // 3. flag==KMEANS_USE_INITIAL_LABELS 270 labels.copyTo( bestLabels ); 271 RNG rng; 272 for( int i = 0; i < 0.5f * pointsCount; i++ ) 273 bestLabels.at<int>( rng.next() % pointsCount, 0 ) = rng.next() % 3; 274 kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_USE_INITIAL_LABELS, noArray() ); 275 if( !calcErr( bestLabels, labels, sizes, err, false ) ) 276 { 277 ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_USE_INITIAL_LABELS.\n" ); 278 code = cvtest::TS::FAIL_INVALID_OUTPUT; 279 } 280 else if( err > 0.01f ) 281 { 282 ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_USE_INITIAL_LABELS.\n", err ); 283 code = cvtest::TS::FAIL_BAD_ACCURACY; 284 } 285 286 ts->set_failed_test_info( code ); 287} 288 289//-------------------------------------------------------------------------------------------- 290class CV_KNearestTest : public cvtest::BaseTest { 291public: 292 CV_KNearestTest() {} 293protected: 294 virtual void run( int start_from ); 295}; 296 297void CV_KNearestTest::run( int /*start_from*/ ) 298{ 299 int sizesArr[] = { 500, 700, 800 }; 300 int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2]; 301 302 // train data 303 Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels; 304 vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) ); 305 Mat means; 306 vector<Mat> covs; 307 defaultDistribs( means, covs ); 308 generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1, CV_32FC1 ); 309 310 // test data 311 Mat testData( pointsCount, 2, CV_32FC1 ), testLabels, bestLabels; 312 generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 ); 313 314 int code = cvtest::TS::OK; 315 316 // KNearest default implementation 317 Ptr<KNearest> knearest = KNearest::create(); 318 knearest->train(trainData, ml::ROW_SAMPLE, trainLabels); 319 knearest->findNearest(testData, 4, bestLabels); 320 float err; 321 if( !calcErr( bestLabels, testLabels, sizes, err, true ) ) 322 { 323 ts->printf( cvtest::TS::LOG, "Bad output labels.\n" ); 324 code = cvtest::TS::FAIL_INVALID_OUTPUT; 325 } 326 else if( err > 0.01f ) 327 { 328 ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err ); 329 code = cvtest::TS::FAIL_BAD_ACCURACY; 330 } 331 332 // KNearest KDTree implementation 333 Ptr<KNearest> knearestKdt = KNearest::create(); 334 knearestKdt->setAlgorithmType(KNearest::KDTREE); 335 knearestKdt->train(trainData, ml::ROW_SAMPLE, trainLabels); 336 knearestKdt->findNearest(testData, 4, bestLabels); 337 if( !calcErr( bestLabels, testLabels, sizes, err, true ) ) 338 { 339 ts->printf( cvtest::TS::LOG, "Bad output labels.\n" ); 340 code = cvtest::TS::FAIL_INVALID_OUTPUT; 341 } 342 else if( err > 0.01f ) 343 { 344 ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err ); 345 code = cvtest::TS::FAIL_BAD_ACCURACY; 346 } 347 348 ts->set_failed_test_info( code ); 349} 350 351class EM_Params 352{ 353public: 354 EM_Params(int _nclusters=10, int _covMatType=EM::COV_MAT_DIAGONAL, int _startStep=EM::START_AUTO_STEP, 355 const cv::TermCriteria& _termCrit=cv::TermCriteria(cv::TermCriteria::COUNT+cv::TermCriteria::EPS, 100, FLT_EPSILON), 356 const cv::Mat* _probs=0, const cv::Mat* _weights=0, 357 const cv::Mat* _means=0, const std::vector<cv::Mat>* _covs=0) 358 : nclusters(_nclusters), covMatType(_covMatType), startStep(_startStep), 359 probs(_probs), weights(_weights), means(_means), covs(_covs), termCrit(_termCrit) 360 {} 361 362 int nclusters; 363 int covMatType; 364 int startStep; 365 366 // all 4 following matrices should have type CV_32FC1 367 const cv::Mat* probs; 368 const cv::Mat* weights; 369 const cv::Mat* means; 370 const std::vector<cv::Mat>* covs; 371 372 cv::TermCriteria termCrit; 373}; 374 375//-------------------------------------------------------------------------------------------- 376class CV_EMTest : public cvtest::BaseTest 377{ 378public: 379 CV_EMTest() {} 380protected: 381 virtual void run( int start_from ); 382 int runCase( int caseIndex, const EM_Params& params, 383 const cv::Mat& trainData, const cv::Mat& trainLabels, 384 const cv::Mat& testData, const cv::Mat& testLabels, 385 const vector<int>& sizes); 386}; 387 388int CV_EMTest::runCase( int caseIndex, const EM_Params& params, 389 const cv::Mat& trainData, const cv::Mat& trainLabels, 390 const cv::Mat& testData, const cv::Mat& testLabels, 391 const vector<int>& sizes ) 392{ 393 int code = cvtest::TS::OK; 394 395 cv::Mat labels; 396 float err; 397 398 Ptr<EM> em = EM::create(); 399 em->setClustersNumber(params.nclusters); 400 em->setCovarianceMatrixType(params.covMatType); 401 em->setTermCriteria(params.termCrit); 402 if( params.startStep == EM::START_AUTO_STEP ) 403 em->trainEM( trainData, noArray(), labels, noArray() ); 404 else if( params.startStep == EM::START_E_STEP ) 405 em->trainE( trainData, *params.means, *params.covs, 406 *params.weights, noArray(), labels, noArray() ); 407 else if( params.startStep == EM::START_M_STEP ) 408 em->trainM( trainData, *params.probs, 409 noArray(), labels, noArray() ); 410 411 // check train error 412 if( !calcErr( labels, trainLabels, sizes, err , false, false ) ) 413 { 414 ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex ); 415 code = cvtest::TS::FAIL_INVALID_OUTPUT; 416 } 417 else if( err > 0.008f ) 418 { 419 ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on train data.\n", caseIndex, err ); 420 code = cvtest::TS::FAIL_BAD_ACCURACY; 421 } 422 423 // check test error 424 labels.create( testData.rows, 1, CV_32SC1 ); 425 for( int i = 0; i < testData.rows; i++ ) 426 { 427 Mat sample = testData.row(i); 428 Mat probs; 429 labels.at<int>(i) = static_cast<int>(em->predict2( sample, probs )[1]); 430 } 431 if( !calcErr( labels, testLabels, sizes, err, false, false ) ) 432 { 433 ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex ); 434 code = cvtest::TS::FAIL_INVALID_OUTPUT; 435 } 436 else if( err > 0.008f ) 437 { 438 ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on test data.\n", caseIndex, err ); 439 code = cvtest::TS::FAIL_BAD_ACCURACY; 440 } 441 442 return code; 443} 444 445void CV_EMTest::run( int /*start_from*/ ) 446{ 447 int sizesArr[] = { 500, 700, 800 }; 448 int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2]; 449 450 // Points distribution 451 Mat means; 452 vector<Mat> covs; 453 defaultDistribs( means, covs, CV_64FC1 ); 454 455 // train data 456 Mat trainData( pointsCount, 2, CV_64FC1 ), trainLabels; 457 vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) ); 458 generateData( trainData, trainLabels, sizes, means, covs, CV_64FC1, CV_32SC1 ); 459 460 // test data 461 Mat testData( pointsCount, 2, CV_64FC1 ), testLabels; 462 generateData( testData, testLabels, sizes, means, covs, CV_64FC1, CV_32SC1 ); 463 464 EM_Params params; 465 params.nclusters = 3; 466 Mat probs(trainData.rows, params.nclusters, CV_64FC1, cv::Scalar(1)); 467 params.probs = &probs; 468 Mat weights(1, params.nclusters, CV_64FC1, cv::Scalar(1)); 469 params.weights = &weights; 470 params.means = &means; 471 params.covs = &covs; 472 473 int code = cvtest::TS::OK; 474 int caseIndex = 0; 475 { 476 params.startStep = EM::START_AUTO_STEP; 477 params.covMatType = EM::COV_MAT_GENERIC; 478 int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes); 479 code = currCode == cvtest::TS::OK ? code : currCode; 480 } 481 { 482 params.startStep = EM::START_AUTO_STEP; 483 params.covMatType = EM::COV_MAT_DIAGONAL; 484 int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes); 485 code = currCode == cvtest::TS::OK ? code : currCode; 486 } 487 { 488 params.startStep = EM::START_AUTO_STEP; 489 params.covMatType = EM::COV_MAT_SPHERICAL; 490 int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes); 491 code = currCode == cvtest::TS::OK ? code : currCode; 492 } 493 { 494 params.startStep = EM::START_M_STEP; 495 params.covMatType = EM::COV_MAT_GENERIC; 496 int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes); 497 code = currCode == cvtest::TS::OK ? code : currCode; 498 } 499 { 500 params.startStep = EM::START_M_STEP; 501 params.covMatType = EM::COV_MAT_DIAGONAL; 502 int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes); 503 code = currCode == cvtest::TS::OK ? code : currCode; 504 } 505 { 506 params.startStep = EM::START_M_STEP; 507 params.covMatType = EM::COV_MAT_SPHERICAL; 508 int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes); 509 code = currCode == cvtest::TS::OK ? code : currCode; 510 } 511 { 512 params.startStep = EM::START_E_STEP; 513 params.covMatType = EM::COV_MAT_GENERIC; 514 int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes); 515 code = currCode == cvtest::TS::OK ? code : currCode; 516 } 517 { 518 params.startStep = EM::START_E_STEP; 519 params.covMatType = EM::COV_MAT_DIAGONAL; 520 int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes); 521 code = currCode == cvtest::TS::OK ? code : currCode; 522 } 523 { 524 params.startStep = EM::START_E_STEP; 525 params.covMatType = EM::COV_MAT_SPHERICAL; 526 int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes); 527 code = currCode == cvtest::TS::OK ? code : currCode; 528 } 529 530 ts->set_failed_test_info( code ); 531} 532 533class CV_EMTest_SaveLoad : public cvtest::BaseTest { 534public: 535 CV_EMTest_SaveLoad() {} 536protected: 537 virtual void run( int /*start_from*/ ) 538 { 539 int code = cvtest::TS::OK; 540 const int nclusters = 2; 541 542 Mat samples = Mat(3,1,CV_64FC1); 543 samples.at<double>(0,0) = 1; 544 samples.at<double>(1,0) = 2; 545 samples.at<double>(2,0) = 3; 546 547 Mat labels; 548 549 Ptr<EM> em = EM::create(); 550 em->setClustersNumber(nclusters); 551 em->trainEM(samples, noArray(), labels, noArray()); 552 553 Mat firstResult(samples.rows, 1, CV_32SC1); 554 for( int i = 0; i < samples.rows; i++) 555 firstResult.at<int>(i) = static_cast<int>(em->predict2(samples.row(i), noArray())[1]); 556 557 // Write out 558 string filename = cv::tempfile(".xml"); 559 { 560 FileStorage fs = FileStorage(filename, FileStorage::WRITE); 561 try 562 { 563 fs << "em" << "{"; 564 em->write(fs); 565 fs << "}"; 566 } 567 catch(...) 568 { 569 ts->printf( cvtest::TS::LOG, "Crash in write method.\n" ); 570 ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION ); 571 } 572 } 573 574 em.release(); 575 576 // Read in 577 try 578 { 579 em = Algorithm::load<EM>(filename); 580 } 581 catch(...) 582 { 583 ts->printf( cvtest::TS::LOG, "Crash in read method.\n" ); 584 ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION ); 585 } 586 587 remove( filename.c_str() ); 588 589 int errCaseCount = 0; 590 for( int i = 0; i < samples.rows; i++) 591 errCaseCount = std::abs(em->predict2(samples.row(i), noArray())[1] - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1; 592 593 if( errCaseCount > 0 ) 594 { 595 ts->printf( cvtest::TS::LOG, "Different prediction results before writeing and after reading (errCaseCount=%d).\n", errCaseCount ); 596 code = cvtest::TS::FAIL_BAD_ACCURACY; 597 } 598 599 ts->set_failed_test_info( code ); 600 } 601}; 602 603class CV_EMTest_Classification : public cvtest::BaseTest 604{ 605public: 606 CV_EMTest_Classification() {} 607protected: 608 virtual void run(int) 609 { 610 // This test classifies spam by the following way: 611 // 1. estimates distributions of "spam" / "not spam" 612 // 2. predict classID using Bayes classifier for estimated distributions. 613 614 string dataFilename = string(ts->get_data_path()) + "spambase.data"; 615 Ptr<TrainData> data = TrainData::loadFromCSV(dataFilename, 0); 616 617 if( data.empty() ) 618 { 619 ts->printf(cvtest::TS::LOG, "File with spambase dataset cann't be read.\n"); 620 ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA); 621 } 622 623 Mat samples = data->getSamples(); 624 CV_Assert(samples.cols == 57); 625 Mat responses = data->getResponses(); 626 627 vector<int> trainSamplesMask(samples.rows, 0); 628 int trainSamplesCount = (int)(0.5f * samples.rows); 629 for(int i = 0; i < trainSamplesCount; i++) 630 trainSamplesMask[i] = 1; 631 RNG rng(0); 632 for(size_t i = 0; i < trainSamplesMask.size(); i++) 633 { 634 int i1 = rng(static_cast<unsigned>(trainSamplesMask.size())); 635 int i2 = rng(static_cast<unsigned>(trainSamplesMask.size())); 636 std::swap(trainSamplesMask[i1], trainSamplesMask[i2]); 637 } 638 639 Mat samples0, samples1; 640 for(int i = 0; i < samples.rows; i++) 641 { 642 if(trainSamplesMask[i]) 643 { 644 Mat sample = samples.row(i); 645 int resp = (int)responses.at<float>(i); 646 if(resp == 0) 647 samples0.push_back(sample); 648 else 649 samples1.push_back(sample); 650 } 651 } 652 Ptr<EM> model0 = EM::create(); 653 model0->setClustersNumber(3); 654 model0->trainEM(samples0, noArray(), noArray(), noArray()); 655 656 Ptr<EM> model1 = EM::create(); 657 model1->setClustersNumber(3); 658 model1->trainEM(samples1, noArray(), noArray(), noArray()); 659 660 Mat trainConfusionMat(2, 2, CV_32SC1, Scalar(0)), 661 testConfusionMat(2, 2, CV_32SC1, Scalar(0)); 662 const double lambda = 1.; 663 for(int i = 0; i < samples.rows; i++) 664 { 665 Mat sample = samples.row(i); 666 double sampleLogLikelihoods0 = model0->predict2(sample, noArray())[0]; 667 double sampleLogLikelihoods1 = model1->predict2(sample, noArray())[0]; 668 669 int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1; 670 671 if(trainSamplesMask[i]) 672 trainConfusionMat.at<int>((int)responses.at<float>(i), classID)++; 673 else 674 testConfusionMat.at<int>((int)responses.at<float>(i), classID)++; 675 } 676// std::cout << trainConfusionMat << std::endl; 677// std::cout << testConfusionMat << std::endl; 678 679 double trainError = (double)(trainConfusionMat.at<int>(1,0) + trainConfusionMat.at<int>(0,1)) / trainSamplesCount; 680 double testError = (double)(testConfusionMat.at<int>(1,0) + testConfusionMat.at<int>(0,1)) / (samples.rows - trainSamplesCount); 681 const double maxTrainError = 0.23; 682 const double maxTestError = 0.26; 683 684 int code = cvtest::TS::OK; 685 if(trainError > maxTrainError) 686 { 687 ts->printf(cvtest::TS::LOG, "Too large train classification error (calc = %f, valid=%f).\n", trainError, maxTrainError); 688 code = cvtest::TS::FAIL_INVALID_TEST_DATA; 689 } 690 if(testError > maxTestError) 691 { 692 ts->printf(cvtest::TS::LOG, "Too large test classification error (calc = %f, valid=%f).\n", testError, maxTestError); 693 code = cvtest::TS::FAIL_INVALID_TEST_DATA; 694 } 695 696 ts->set_failed_test_info(code); 697 } 698}; 699 700TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); } 701TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); } 702TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); } 703TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); } 704TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); } 705