/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright( C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "_ml.h"


/*
   CvEM:
 * params.nclusters - number of clusters (mixture components) to group the samples into.
 * means - the Gaussians' means computed by the EM algorithm.
 * log_weight_div_det - auxiliary vector whose k-th component is equal to
       (-2)*ln(weights_k/det(Sigma_k)^0.5),
       where <weights_k> is the weight and
       <Sigma_k> is the covariance matrix of the k-th cluster.
 * inv_eigen_values - nclusters x dims matrix (nclusters x 1 for COV_MAT_SPHERICAL);
       its k-th row, <inv_eigen_values>[k], contains the inverse eigenvalues
       of the covariance matrix of the k-th cluster.
       In the case of <cov_mat_type> == COV_MAT_DIAGONAL,
       inv_eigen_values[k] is the diagonal of Sigma_k^(-1).
 * cov_rotate_mats - meaningful only if cov_mat_type == COV_MAT_GENERIC and not
       used in all other cases. <cov_rotate_mats>[k] is the orthogonal
       matrix obtained by the SVD of Sigma_k.
   Both <inv_eigen_values> and <cov_rotate_mats> are used to represent the
   covariance matrices and to simplify the EM calculations.
   For fixed k denote
       u = cov_rotate_mats[k],
       v = inv_eigen_values[k],
       w = v^(-1) (element-wise);
   if <cov_mat_type> == COV_MAT_GENERIC, then Sigma_k = u diag(w) u',
   else Sigma_k = diag(w) (a scalar multiple of the identity for COV_MAT_SPHERICAL).
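   Worked form of the per-cluster score (derived only from the definitions
   above; N(x; mean_k, Sigma_k) denotes the Gaussian density). Since u is
   orthogonal, Sigma_k^(-1) = u diag(v) u', so for a sample x and
       d = u'(x - mean_k)   (d = x - mean_k when cov_mat_type != COV_MAT_GENERIC)
   the Mahalanobis term becomes a weighted sum of squares:
       (x - mean_k)' Sigma_k^(-1) (x - mean_k) = sum_i v_i * d_i^2.
   Expanding the log of the weighted density then gives
       -2*ln(weights_k * N(x; mean_k, Sigma_k))
           = log_weight_div_det[k] + sum_i v_i * d_i^2 + dims*ln(2*pi),
   so the most probable cluster is the k with the smallest value of
   log_weight_div_det[k] + sum_i v_i * d_i^2; the dims*ln(2*pi) term is the
   same for every k. For COV_MAT_SPHERICAL, v has a single element v_0 and the
   sum reduces to v_0 * |x - mean_k|^2.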
   Symbol ' means transposition.
 */


CvEM::CvEM()
{
    means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
    covs = cov_rotate_mats = 0;
}

CvEM::CvEM( const CvMat* samples, const CvMat* sample_idx,
            CvEMParams params, CvMat* labels )
{
    means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
    covs = cov_rotate_mats = 0;

    // just invoke the train() method
    train(samples, sample_idx, params, labels);
}

CvEM::~CvEM()
{
    clear();
}


void CvEM::clear()
{
    int i;

    cvReleaseMat( &means );
    cvReleaseMat( &weights );
    cvReleaseMat( &probs );
    cvReleaseMat( &inv_eigen_values );
    cvReleaseMat( &log_weight_div_det );

    if( covs || cov_rotate_mats )
    {
        for( i = 0; i < params.nclusters; i++ )
        {
            if( covs )
                cvReleaseMat( &covs[i] );
            if( cov_rotate_mats )
                cvReleaseMat( &cov_rotate_mats[i] );
        }
        cvFree( &covs );
        cvFree( &cov_rotate_mats );
    }
}


void CvEM::set_params( const CvEMParams& _params, const CvVectors& train_data )
{
    CV_FUNCNAME( "CvEM::set_params" );

    __BEGIN__;

    int k;

    params = _params;
    params.term_crit = cvCheckTermCriteria( params.term_crit, 1e-6, 10000 );

    if( params.cov_mat_type != COV_MAT_SPHERICAL &&
        params.cov_mat_type != COV_MAT_DIAGONAL &&
        params.cov_mat_type != COV_MAT_GENERIC )
        CV_ERROR( CV_StsBadArg, "Unknown covariance matrix type" );

    switch( params.start_step )
    {
    case START_M_STEP:
        if( !params.probs )
            CV_ERROR( CV_StsNullPtr, "Probabilities must be specified when EM algorithm starts with M-step" );
        break;
    case START_E_STEP:
        if( !params.means )
            CV_ERROR( CV_StsNullPtr, "Means must be specified when EM algorithm starts with E-step" );
        break;
    case START_AUTO_STEP:
        break;
    default:
        CV_ERROR( CV_StsBadArg, "Unknown start_step" );
    }

    if( params.nclusters < 1 )
        CV_ERROR( CV_StsOutOfRange, "The number of clusters (mixtures) should be > 0" );

    if( params.probs )
    {
        // validate the initial probabilities (the original code mistakenly read params.weights here)
        const CvMat* p = params.probs;
        if( !CV_IS_MAT(p) ||
            (CV_MAT_TYPE(p->type) != CV_32FC1 &&
             CV_MAT_TYPE(p->type) != CV_64FC1) ||
            p->rows != train_data.count ||
            p->cols != params.nclusters )
            CV_ERROR( CV_StsBadArg, "The array of probabilities must be a valid "
            "floating-point matrix (CvMat) of 'nsamples' x 'nclusters' size" );
    }

    if( params.means )
    {
        const CvMat* m = params.means;
        if( !CV_IS_MAT(m) ||
            (CV_MAT_TYPE(m->type) != CV_32FC1 &&
             CV_MAT_TYPE(m->type) != CV_64FC1) ||
            m->rows != params.nclusters ||
            m->cols != train_data.dims )
            CV_ERROR( CV_StsBadArg, "The array of means must be a valid "
            "floating-point matrix (CvMat) of 'nclusters' x 'dims' size" );
    }

    if( params.weights )
    {
        const CvMat* w = params.weights;
        if( !CV_IS_MAT(w) ||
            (CV_MAT_TYPE(w->type) != CV_32FC1 &&
             CV_MAT_TYPE(w->type) != CV_64FC1) ||
            (w->rows != 1 && w->cols != 1) ||
            w->rows + w->cols - 1 != params.nclusters )
            CV_ERROR( CV_StsBadArg, "The array of weights must be a valid "
            "1d floating-point vector (CvMat) of 'nclusters' elements" );
    }

    if( params.covs )
        for( k = 0; k < params.nclusters; k++ )
        {
            const CvMat* cov = params.covs[k];
            if( !CV_IS_MAT(cov) ||
                (CV_MAT_TYPE(cov->type) != CV_32FC1 &&
                 CV_MAT_TYPE(cov->type) != CV_64FC1) ||
                cov->rows != cov->cols || cov->cols != train_data.dims )
                CV_ERROR( CV_StsBadArg,
                "Each of the covariance matrices must be a valid square "
                "floating-point matrix (CvMat) of 'dims' x 'dims' size" );
        }

    __END__;
}


/****************************************************************************************/
float
CvEM::predict( const CvMat* _sample, CvMat* _probs ) const
{
    float* sample_data = 0;
    void* buffer = 0;
    int allocated_buffer = 0;
    int cls = 0;

    CV_FUNCNAME( "CvEM::predict" );
    __BEGIN__;

    int i, k, dims;
    int nclusters;
    int cov_mat_type = params.cov_mat_type;
    double opt = FLT_MAX;
    size_t size;
    CvMat diff, expo;

    dims = means->cols;
    nclusters = params.nclusters;

    CV_CALL( cvPreparePredictData( _sample, dims, 0, params.nclusters, _probs, &sample_data ));

    // allocate memory and initialize the matrix headers used in the calculations
    size = sizeof(double) * (nclusters + dims);
    if( size <= CV_MAX_LOCAL_SIZE )
        buffer = cvStackAlloc( size );
    else
    {
        CV_CALL( buffer = cvAlloc( size ));
        allocated_buffer = 1;
    }
    expo = cvMat( 1, nclusters, CV_64FC1, buffer );
    diff = cvMat( 1, dims, CV_64FC1, (double*)buffer + nclusters );

    // calculate the probabilities
    for( k = 0; k < nclusters; k++ )
    {
        const double* mean_k = (const double*)(means->data.ptr + means->step*k);
        const double* w = (const double*)(inv_eigen_values->data.ptr + inv_eigen_values->step*k);
        double cur = log_weight_div_det->data.db[k];
        CvMat* u = cov_rotate_mats[k];
        // cov = u diag(eig) u'  -->  cov^(-1) = u diag(1/eig) u';
        // the local pointer w already holds the inverse eigenvalues 1/eig
        if( cov_mat_type == COV_MAT_SPHERICAL )
        {
            double w0 = w[0];
            for( i = 0; i < dims; i++ )
            {
                double val = sample_data[i] - mean_k[i];
                cur += val*val*w0;
            }
        }
        else
        {
            for( i = 0; i < dims; i++ )
                diff.data.db[i] = sample_data[i] - mean_k[i];
            if( cov_mat_type == COV_MAT_GENERIC )
                cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T );
            for( i = 0; i < dims; i++ )
            {
                double val = diff.data.db[i];
                cur += val*val*w[i];
            }
        }

        expo.data.db[k] = cur;
        if( cur < opt )
        {
            cls = k;
            opt = cur;
        }
        /* weights_k * N(sample; mean_k, Sigma_k) = (2*pi)^(-dims/2)*exp( -0.5 * cur );
           the constant factor cancels when _probs is normalized below */
    }

    if( _probs )
    {
        // shift by the smallest score (opt) before exponentiating so that the best
        // cluster gets exp(0) = 1 and the sum cannot underflow to zero;
        // the common factor exp(0.5*opt) cancels in the normalization
        CV_CALL( cvConvertScale( &expo, &expo, -0.5, 0.5*opt ));
        CV_CALL( cvExp( &expo, &expo ));
        if( _probs->cols == 1 )
            CV_CALL( cvReshape( &expo, &expo, 0, nclusters ));
        CV_CALL( cvConvertScale( &expo, _probs, 1./cvSum( &expo ).val[0] ));
    }

    __END__;

    if( sample_data != _sample->data.fl )
        cvFree( &sample_data );
    if( allocated_buffer )
        cvFree( &buffer );

    return (float)cls;
}



bool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx,
                  CvEMParams _params, CvMat* labels )
{
    bool result = false;
    CvVectors train_data;
    CvMat* sample_idx = 0;

    train_data.data.fl = 0;
    train_data.count = 0;

    CV_FUNCNAME("CvEM::train");

    __BEGIN__;

    int i, nsamples, nclusters, dims;

    clear();

    CV_CALL( cvPrepareTrainData( "cvEM",
        _samples, CV_ROW_SAMPLE, 0, CV_VAR_CATEGORICAL,
        0, _sample_idx, false, (const float***)&train_data.data.fl,
        &train_data.count, &train_data.dims, &train_data.dims,
        0, 0, 0, &sample_idx ));

    CV_CALL( set_params( _params, train_data ));
    nsamples = train_data.count;
    nclusters = params.nclusters;
    dims = train_data.dims;

    if( labels && (!CV_IS_MAT(labels) || CV_MAT_TYPE(labels->type) != CV_32SC1 ||
        (labels->cols != 1 && labels->rows != 1) || labels->cols + labels->rows - 1 != nsamples ))
        CV_ERROR( CV_StsBadArg,
        "labels array (when passed) must be a valid 1d integer vector of <sample_count> elements" );

    if( nsamples <= nclusters )
        CV_ERROR( CV_StsOutOfRange,
        "The number of samples should be greater than the number of clusters" );

    CV_CALL( log_weight_div_det = cvCreateMat( 1, nclusters, CV_64FC1 ));
    CV_CALL( probs = cvCreateMat( nsamples, nclusters, CV_64FC1 ));
    CV_CALL( means = cvCreateMat( nclusters, dims, CV_64FC1 ));
    CV_CALL( weights = cvCreateMat( 1, nclusters, CV_64FC1 ));
    CV_CALL( inv_eigen_values = cvCreateMat( nclusters,
        params.cov_mat_type == COV_MAT_SPHERICAL ? 1 : dims, CV_64FC1 ));
    CV_CALL( covs = (CvMat**)cvAlloc( nclusters * sizeof(*covs) ));
    CV_CALL( cov_rotate_mats = (CvMat**)cvAlloc( nclusters * sizeof(cov_rotate_mats[0]) ));

    for( i = 0; i < nclusters; i++ )
    {
        CV_CALL( covs[i] = cvCreateMat( dims, dims, CV_64FC1 ));
        CV_CALL( cov_rotate_mats[i] = cvCreateMat( dims, dims, CV_64FC1 ));
        cvZero( cov_rotate_mats[i] );
    }

    init_em( train_data );
    log_likelihood = run_em( train_data );
    if( log_likelihood <= -DBL_MAX/10000. )
        EXIT;

    if( labels )
    {
        if( nclusters == 1 )
            cvZero( labels );
        else
        {
            CvMat sample = cvMat( 1, dims, CV_32F );
            CvMat prob = cvMat( 1, nclusters, CV_64F );
            int lstep = CV_IS_MAT_CONT(labels->type) ? 1 : labels->step/sizeof(int);

            for( i = 0; i < nsamples; i++ )
            {
                int idx = sample_idx ? sample_idx->data.i[i] : i;
                sample.data.ptr = _samples->data.ptr + _samples->step*idx;
                prob.data.ptr = probs->data.ptr + probs->step*i;

                labels->data.i[i*lstep] = cvRound(predict(&sample, &prob));
            }
        }
    }

    result = true;

    __END__;

    if( sample_idx != _sample_idx )
        cvReleaseMat( &sample_idx );

    cvFree( &train_data.data.ptr );

    return result;
}


void CvEM::init_em( const CvVectors& train_data )
{
    CvMat *w = 0, *u = 0, *tcov = 0;

    CV_FUNCNAME( "CvEM::init_em" );

    __BEGIN__;

    double maxval = 0;
    int i, force_symm_plus = 0;
    int nclusters = params.nclusters, nsamples = train_data.count, dims = train_data.dims;

    if( params.start_step == START_AUTO_STEP || nclusters == 1 || nclusters == nsamples )
        init_auto( train_data );
    else if( params.start_step == START_M_STEP )
    {
        for( i = 0; i < nsamples; i++ )
        {
            CvMat prob;
            cvGetRow( params.probs, &prob, i );
            cvMaxS( &prob, 0., &prob );
            cvMinMaxLoc( &prob, 0, &maxval );
            if( maxval < FLT_EPSILON )
                cvSet( &prob, cvScalar(1./nclusters) );
            else
                cvNormalize( &prob, &prob, 1., 0, CV_L1 );
        }
        EXIT; // do not preprocess covariance matrices,
              // as in this case they are initialized at the first iteration of EM
    }
    else
    {
        CV_ASSERT( params.start_step == START_E_STEP && params.means );
        if( params.weights && params.covs )
        {
            cvConvert( params.means, means );
            cvReshape( weights, weights, 1, params.weights->rows );
            cvConvert( params.weights, weights );
            cvReshape( weights, weights, 1, 1 );
            cvMaxS( weights, 0., weights );
            cvMinMaxLoc( weights, 0, &maxval );
            if( maxval < FLT_EPSILON )
                cvSet( weights, cvScalar(1./nclusters) );
            cvNormalize( weights, weights, 1., 0, CV_L1 );
            for( i = 0; i < nclusters; i++ )
                CV_CALL( cvConvert( params.covs[i], covs[i] ));
            force_symm_plus = 1;
        }
        else
            init_auto( train_data );
    }

    CV_CALL( tcov = cvCreateMat( dims, dims, CV_64FC1 ));
    CV_CALL( w = cvCreateMat( dims, dims, CV_64FC1 ));
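/*
Sketch (not part of the original source): the START_M_STEP branch (each row of
params.probs) and the START_E_STEP branch (the mixture weights) both apply the same
recipe: clamp negative entries to zero, fall back to a uniform distribution when every
entry is below FLT_EPSILON, and otherwise rescale to unit L1 norm. This standalone
version assumes a plain std::vector<double> instead of a CvMat; normalize_distribution
is a hypothetical helper name.

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <vector>

// Hypothetical helper (not an OpenCV function): clamp negatives to zero, then
// rescale to unit L1 norm, or return a uniform distribution when every entry
// is (numerically) zero.
std::vector<double> normalize_distribution(std::vector<double> p)
{
    assert(!p.empty());
    double maxval = 0.0;
    for (double& v : p) {
        v = std::max(v, 0.0);          // like cvMaxS( &prob, 0., &prob )
        maxval = std::max(maxval, v);  // like cvMinMaxLoc( &prob, 0, &maxval )
    }
    if (maxval < FLT_EPSILON) {
        // like cvSet( &prob, cvScalar(1./nclusters) )
        std::fill(p.begin(), p.end(), 1.0 / static_cast<double>(p.size()));
    } else {
        double sum = 0.0;
        for (double v : p) sum += v;   // L1 norm of a non-negative vector
        for (double& v : p) v /= sum;  // like cvNormalize( ..., 1., 0, CV_L1 )
    }
    return p;
}
```

Usage: normalize_distribution({-1.0, 1.0, 3.0}) yields {0, 0.25, 0.75}, and an input
with no positive entry, such as {0.0, -2.0}, yields the uniform {0.5, 0.5}.
*/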
    if( params.cov_mat_type == COV_MAT_GENERIC )
        CV_CALL( u = cvCreateMat( dims, dims, CV_64FC1 ));

    for( i = 0; i < nclusters; i++ )
    {
        if( force_symm_plus )
        {
            cvTranspose( covs[i], tcov );
            cvAddWeighted( covs[i], 0.5, tcov, 0.5, 0, tcov );
        }
        else
            cvCopy( covs[i], tcov );
        cvSVD( tcov, w, u, 0, CV_SVD_MODIFY_A + CV_SVD_U_T + CV_SVD_V_T );
        if( params.cov_mat_type == COV_MAT_SPHERICAL )
            cvSetIdentity( covs[i], cvScalar(cvTrace(w).val[0]/dims) );
        else if( params.cov_mat_type == COV_MAT_DIAGONAL )
            cvCopy( w, covs[i] );
        else
        {
            // generic case: covs[i] = (u')'*max(w,0)*u'
            cvGEMM( u, w, 1, 0, 0, tcov, CV_GEMM_A_T );
            cvGEMM( tcov, u, 1, 0, 0, covs[i], 0 );
        }
    }
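/*
Sketch (not part of the original source): when force_symm_plus is set, each
user-supplied covariance C is first replaced by its symmetric part 0.5*(C + C'), and for
COV_MAT_SPHERICAL the result is the identity scaled by trace(w)/dims, where w holds the
singular values. For a symmetric positive semidefinite matrix the singular values equal
the eigenvalues, so that scale is just the mean of the diagonal; for an indefinite
matrix the SVD sums absolute eigenvalues instead, which this sketch does not reproduce.
Nested std::vector matrices and the helper names symmetrize and spherical_variance are
assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // row-major, assumed square

// Symmetric part S = 0.5*(C + C'), as computed with cvTranspose + cvAddWeighted.
Matrix symmetrize(const Matrix& c)
{
    const std::size_t n = c.size();
    for (const auto& row : c)
        assert(row.size() == n);  // square input assumed
    Matrix s(n, std::vector<double>(n));
    for (std::size_t i = 0; i < n; i++)
        for (std::size_t j = 0; j < n; j++)
            s[i][j] = 0.5 * (c[i][j] + c[j][i]);
    return s;
}

// Mean eigenvalue trace(S)/dims; equals the mean singular value only when S is
// positive semidefinite.
double spherical_variance(const Matrix& s)
{
    assert(!s.empty());
    double trace = 0.0;
    for (std::size_t i = 0; i < s.size(); i++)
        trace += s[i][i];
    return trace / static_cast<double>(s.size());
}
```

Usage: for C = {{2, 1}, {0, 4}} the symmetric part is {{2, 0.5}, {0.5, 4}} (positive
definite), and its spherical variance is (2 + 4)/2 = 3.
*/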

    __END__;

    cvReleaseMat( &w );
    cvReleaseMat( &u );
    cvReleaseMat( &tcov );
}


void CvEM::init_auto( const CvVectors& train_data )
{
    CvMat* hdr = 0;
    const void** vec = 0;
    CvMat* class_ranges = 0;
    CvMat* labels = 0;

    CV_FUNCNAME( "CvEM::init_auto" );

    __BEGIN__;

    int nclusters = params.nclusters, nsamples = train_data.count, dims = train_data.dims;
    int i, j;

    if( nclusters == nsamples )
    {
        CvMat src = cvMat( 1, dims, CV_32F );
        CvMat dst = cvMat( 1, dims, CV_64F );
        for( i = 0; i < nsamples; i++ )
        {
            src.data.ptr = train_data.data.ptr[i];
            dst.data.ptr = means->data.ptr + means->step*i;
            cvConvert( &src, &dst );
            cvZero( covs[i] );
            cvSetIdentity( cov_rotate_mats[i] );
        }
        cvSetIdentity( probs );
        cvSet( weights, cvScalar(1./nclusters) );
    }
    else
    {
        int max_count = 0;

        CV_CALL( class_ranges = cvCreateMat( 1, nclusters+1, CV_32SC1 ));
        if( nclusters > 1 )
        {
            CV_CALL( labels = cvCreateMat( 1, nsamples, CV_32SC1 ));
            kmeans( train_data, nclusters, labels, cvTermCriteria( CV_TERMCRIT_ITER,
                    params.means ? 1 : 10, 0.5 ), params.means );
            CV_CALL( cvSortSamplesByClasses( (const float**)train_data.data.fl,
                                             labels, class_ranges->data.i ));
        }
        else
        {
            class_ranges->data.i[0] = 0;
            class_ranges->data.i[1] = nsamples;
        }

        for( i = 0; i < nclusters; i++ )
        {
            int left = class_ranges->data.i[i], right = class_ranges->data.i[i+1];
            max_count = MAX( max_count, right - left );
        }
        CV_CALL( hdr = (CvMat*)cvAlloc( max_count*sizeof(hdr[0]) ));
        CV_CALL( vec = (const void**)cvAlloc( max_count*sizeof(vec[0]) ));
        hdr[0] = cvMat( 1, dims, CV_32F );
        for( i = 0; i < max_count; i++ )
        {
            vec[i] = hdr + i;
            hdr[i] = hdr[0];
        }

        for( i = 0; i < nclusters; i++ )
        {
            int left = class_ranges->data.i[i], right = class_ranges->data.i[i+1];
            int cluster_size = right - left;
            CvMat avg;

            if( cluster_size <= 0 )
                continue;

            for( j = left; j < right; j++ )
                hdr[j - left].data.fl = train_data.data.fl[j];

            CV_CALL( cvGetRow( means, &avg, i ));
            CV_CALL( cvCalcCovarMatrix( vec, cluster_size, covs[i],
                &avg, CV_COVAR_NORMAL | CV_COVAR_SCALE ));
            weights->data.db[i] = (double)cluster_size/(double)nsamples;
        }
    }

    __END__;

    cvReleaseMat( &class_ranges );
    cvReleaseMat( &labels );
    cvFree( &hdr );
    cvFree( &vec );
}

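/*
Sketch (not part of the original source): for every non-empty k-means cluster,
init_auto sets weights[i] = cluster_size/nsamples and calls cvCalcCovarMatrix with
CV_COVAR_NORMAL | CV_COVAR_SCALE. That call writes the cluster mean into avg (a row of
means) and scales the scatter matrix by 1/cluster_size, which is the maximum-likelihood
covariance rather than the unbiased 1/(cluster_size-1) estimate. This version uses
std::vector instead of CvMat; ClusterStats and cluster_stats are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct ClusterStats {
    double weight;                         // cluster_size / nsamples
    std::vector<double> mean;              // like row i of means
    std::vector<std::vector<double>> cov;  // like covs[i], scaled by 1/cluster_size
};

// Hypothetical helper: weight, mean and ML covariance of one cluster.
ClusterStats cluster_stats(const std::vector<std::vector<double>>& members,
                           std::size_t nsamples)
{
    assert(!members.empty() && nsamples >= members.size());
    const std::size_t n = members.size(), dims = members[0].size();
    ClusterStats st;
    st.weight = static_cast<double>(n) / static_cast<double>(nsamples);

    st.mean.assign(dims, 0.0);
    for (const auto& x : members)
        for (std::size_t j = 0; j < dims; j++)
            st.mean[j] += x[j];
    for (double& m : st.mean)
        m /= static_cast<double>(n);

    st.cov.assign(dims, std::vector<double>(dims, 0.0));
    for (const auto& x : members)
        for (std::size_t a = 0; a < dims; a++)
            for (std::size_t b = 0; b < dims; b++)
                st.cov[a][b] += (x[a] - st.mean[a]) * (x[b] - st.mean[b]);
    for (auto& row : st.cov)
        for (double& v : row)
            v /= static_cast<double>(n);   // CV_COVAR_SCALE: 1/count
    return st;
}
```

Usage: for the four corners of the square [0,2] x [0,2], taken from 8 samples in
total, the sketch returns weight 0.5, mean (1, 1) and the 2x2 identity as covariance.
*/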
void CvEM::kmeans( const CvVectors& train_data, int nclusters, CvMat* labels,
                   CvTermCriteria termcrit, const CvMat* centers0 )
{
    CvMat* centers = 0;
    CvMat* old_centers = 0;
    CvMat* counters = 0;

    CV_FUNCNAME( "CvEM::kmeans" );

    __BEGIN__;

    CvRNG rng = cvRNG(-1);
    int i, j, k, nsamples, dims;
    int iter = 0;
    double max_dist = DBL_MAX;

    termcrit = cvCheckTermCriteria( termcrit, 1e-6, 100 );
    termcrit.epsilon *= termcrit.epsilon;
    nsamples = train_data.count;
    dims = train_data.dims;
    nclusters = MIN( nclusters, nsamples );

    CV_CALL( centers = cvCreateMat( nclusters, dims, CV_64FC1 ));
    CV_CALL( old_centers = cvCreateMat( nclusters, dims, CV_64FC1 ));
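/*
Sketch (not part of the original source): kmeans squares termcrit.epsilon once because
the convergence test compares it with squared center shifts (max_dist is a squared
Euclidean distance). Both quantities are non-negative, so shift^2 < eps^2 holds exactly
when shift < eps, and no square root is needed per cluster. The helpers
squared_distance and centers_converged are hypothetical and operate on std::vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Squared Euclidean distance, without a square root.
double squared_distance(const std::vector<double>& a, const std::vector<double>& b)
{
    assert(a.size() == b.size());
    double d = 0.0;
    for (std::size_t j = 0; j < a.size(); j++) {
        double t = a[j] - b[j];
        d += t * t;
    }
    return d;
}

// Converged when the largest squared center shift is below eps*eps, which is
// equivalent to the largest unsquared shift being below eps.
bool centers_converged(const std::vector<std::vector<double>>& old_c,
                       const std::vector<std::vector<double>>& new_c, double eps)
{
    assert(old_c.size() == new_c.size());
    const double eps2 = eps * eps;  // like termcrit.epsilon *= termcrit.epsilon
    double max_dist = 0.0;
    for (std::size_t k = 0; k < new_c.size(); k++) {
        double d = squared_distance(new_c[k], old_c[k]);
        if (max_dist < d) max_dist = d;
    }
    return max_dist < eps2;
}
```

Usage: a center moving from (0, 0) to (3, 4) shifts by 5 (squared shift 25), so it
counts as converged for eps = 5.1 but not for eps = 4.9.
*/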
    CV_CALL( counters = cvCreateMat( 1, nclusters, CV_32SC1 ));
    cvZero( old_centers );

    if( centers0 )
    {
        CV_CALL( cvConvert( centers0, centers ));
    }
    else
    {
        for( i = 0; i < nsamples; i++ )
            labels->data.i[i] = i*nclusters/nsamples;
        cvRandShuffle( labels, &rng );
    }

    for( ;; )
    {
        CvMat* temp;

        if( iter > 0 || centers0 )
        {
            for( i = 0; i < nsamples; i++ )
            {
                const float* s = train_data.data.fl[i];
                int k_best = 0;
                double min_dist = DBL_MAX;

                for( k = 0; k < nclusters; k++ )
                {
                    const double* c = (double*)(centers->data.ptr + k*centers->step);
                    double dist = 0;

                    for( j = 0; j <= dims - 4; j += 4 )
                    {
                        double t0 = c[j] - s[j];
                        double t1 = c[j+1] - s[j+1];
                        dist += t0*t0 + t1*t1;
                        t0 = c[j+2] - s[j+2];
                        t1 = c[j+3] - s[j+3];
                        dist += t0*t0 + t1*t1;
                    }

                    for( ; j < dims; j++ )
                    {
                        double t = c[j] - s[j];
                        dist += t*t;
                    }

                    if( min_dist > dist )
                    {
                        min_dist = dist;
                        k_best = k;
                    }
                }

                labels->data.i[i] = k_best;
            }
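/*
Sketch (not part of the original source): the assignment step labels each sample with
its nearest center by squared Euclidean distance, unrolling the inner loop by four and
handling the remaining dims % 4 coordinates in a tail loop. Ties go to the lowest center
index because the comparison (min_dist > dist) is strict. The helpers sq_dist_unrolled
and nearest_center are hypothetical; centers are double rows and samples are float,
mirroring the precision mix of the original.

```cpp
#include <cassert>
#include <cfloat>
#include <cstddef>
#include <vector>

// Squared distance with the same 4-way unrolling and tail loop as above.
double sq_dist_unrolled(const double* c, const float* s, int dims)
{
    double dist = 0;
    int j = 0;
    for (; j <= dims - 4; j += 4) {
        double t0 = c[j] - s[j];
        double t1 = c[j+1] - s[j+1];
        dist += t0*t0 + t1*t1;
        t0 = c[j+2] - s[j+2];
        t1 = c[j+3] - s[j+3];
        dist += t0*t0 + t1*t1;
    }
    for (; j < dims; j++) {
        double t = c[j] - s[j];
        dist += t*t;
    }
    return dist;
}

// Index of the nearest center; the first center wins on ties.
int nearest_center(const std::vector<std::vector<double>>& centers,
                   const std::vector<float>& s)
{
    assert(!centers.empty());
    int k_best = 0;
    double min_dist = DBL_MAX;
    for (std::size_t k = 0; k < centers.size(); k++) {
        assert(centers[k].size() == s.size());
        double d = sq_dist_unrolled(centers[k].data(), s.data(),
                                    static_cast<int>(s.size()));
        if (min_dist > d) {
            min_dist = d;
            k_best = static_cast<int>(k);
        }
    }
    return k_best;
}
```

With dims = 5 the unrolled loop covers coordinates 0..3 and the tail loop covers
coordinate 4, so both code paths are exercised.
*/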
        }

        if( ++iter > termcrit.max_iter )
            break;

        CV_SWAP( centers, old_centers, temp );
        cvZero( centers );
        cvZero( counters );

        // update centers
        for( i = 0; i < nsamples; i++ )
        {
            const float* s = train_data.data.fl[i];
            k = labels->data.i[i];
            double* c = (double*)(centers->data.ptr + k*centers->step);

            for( j = 0; j <= dims - 4; j += 4 )
            {
                double t0 = c[j] + s[j];
                double t1 = c[j+1] + s[j+1];

                c[j] = t0;
                c[j+1] = t1;

                t0 = c[j+2] + s[j+2];
                t1 = c[j+3] + s[j+3];

                c[j+2] = t0;
                c[j+3] = t1;
            }
            for( ; j < dims; j++ )
                c[j] += s[j];
            counters->data.i[k]++;
        }

        if( iter > 1 )
            max_dist = 0;

        for( k = 0; k < nclusters; k++ )
        {
            double* c = (double*)(centers->data.ptr + k*centers->step);
            if( counters->data.i[k] != 0 )
            {
                double scale = 1./counters->data.i[k];
                for( j = 0; j < dims; j++ )
                    c[j] *= scale;
            }
            else
            {
                const float* s;
                for( j = 0; j < 10; j++ )
                {
                    i = cvRandInt( &rng ) % nsamples;
                    if( counters->data.i[labels->data.i[i]] > 1 )
                        break;
                }
                s = train_data.data.fl[i];
                for( j = 0; j < dims; j++ )
                    c[j] = s[j];
            }

            if( iter > 1 )
            {
                double dist = 0;
                const double* c_o = (double*)(old_centers->data.ptr + k*old_centers->step);
                for( j = 0; j < dims; j++ )
                {
                    double t = c[j] - c_o[j];
                    dist += t*t;
                }
                if( max_dist < dist )
                    max_dist = dist;
            }
        }

        if( max_dist < termcrit.epsilon )
            break;
    }

    cvZero( counters );
    for( i = 0; i < nsamples; i++ )
        counters->data.i[labels->data.i[i]]++;

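/*
Sketch (not part of the original source): the "update centers" step zeroes the
centers, accumulates the sum of the samples assigned to each cluster, and divides by
the member count. The original code reseeds a cluster that received no samples from a
random sample; this deterministic sketch simply leaves such a center at zero and returns
the counts so the caller can handle it. update_centers is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: recompute each center as the mean of its assigned samples.
// Returns per-cluster member counts; empty clusters keep an all-zero center.
std::vector<int> update_centers(const std::vector<std::vector<double>>& samples,
                                const std::vector<int>& labels,
                                std::vector<std::vector<double>>& centers)
{
    assert(samples.size() == labels.size());
    std::vector<int> counts(centers.size(), 0);
    for (auto& c : centers)
        for (double& v : c) v = 0.0;                  // like cvZero( centers )
    for (std::size_t i = 0; i < samples.size(); i++) {
        assert(labels[i] >= 0 &&
               static_cast<std::size_t>(labels[i]) < centers.size());
        std::vector<double>& c = centers[labels[i]];
        assert(samples[i].size() == c.size());
        for (std::size_t j = 0; j < c.size(); j++)
            c[j] += samples[i][j];                    // accumulate sums
        counts[labels[i]]++;
    }
    for (std::size_t k = 0; k < centers.size(); k++) {
        if (counts[k] != 0) {
            double scale = 1.0 / counts[k];
            for (double& v : centers[k]) v *= scale;  // sum -> mean
        }
    }
    return counts;
}
```

Usage: samples (0,0), (2,2), (10,10) with labels {0, 0, 1} give centers (1,1) and
(10,10) and counts {2, 1}.
*/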
    // ensure that we do not have empty clusters
    for( k = 0; k < nclusters; k++ )
        if( counters->data.i[k] == 0 )
            for(;;)
            {
                i = cvRandInt(&rng) % nsamples;
                j = labels->data.i[i];
                if( counters->data.i[j] > 1 )
                {
                    labels->data.i[i] = k;
                    counters->data.i[j]--;
                    counters->data.i[k]++;
                    break;
                }
            }

    __END__;

    cvReleaseMat( &centers );
    cvReleaseMat( &old_centers );
    cvReleaseMat( &counters );
}


/****************************************************************************************/
/* log_weight_div_det[k] = -2*log(weights_k) + log(det(Sigma_k))

   covs[k] = cov_rotate_mats[k] * cov_eigen_values[k] * (cov_rotate_mats[k])'
   cov_rotate_mats[k] are orthogonal matrices of eigenvectors and
   cov_eigen_values[k] are diagonal matrices (represented by 1D vectors) of eigenvalues.

   <alpha_ik> is the probability that the vector x_i belongs to the k-th cluster:
   <alpha_ik> ~ weights_k * exp{ -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] }
   We compute these probabilities with the following equivalent formulas.
   Denote
   S_ik = -0.5(log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)) + log(weights_k)
        = -0.5(log_weight_div_det[k] + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)),
   M_i = max_k S_ik = S_iq, so that the q-th class is the one where the maximum is reached. Then
   alpha_ik = exp{ S_ik - M_i } / ( 1 + sum_j!=q exp{ S_ij - M_i })
*/
double CvEM::run_em( const CvVectors& train_data )
{
    CvMat* centered_sample = 0;
    CvMat* covs_item = 0;
    CvMat* log_det = 0;
    CvMat* log_weights = 0;
    CvMat* cov_eigen_values = 0;
    CvMat* samples = 0;
    CvMat* sum_probs = 0;
    log_likelihood = -DBL_MAX;

    CV_FUNCNAME( "CvEM::run_em" );
    __BEGIN__;

    int nsamples = train_data.count, dims = train_data.dims, nclusters = params.nclusters;
    double min_variation = FLT_EPSILON;
    double min_det_value = MAX( DBL_MIN, pow( min_variation, dims ));
    double likelihood_bias = -CV_LOG2PI * (double)nsamples * (double)dims / 2., _log_likelihood = -DBL_MAX;
    int start_step = params.start_step;

    int i, j, k, n;
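/*
Sketch (not part of the original source): the comment above obtains alpha_ik from S_ik
by subtracting the per-sample maximum M_i before exponentiating, so the largest term
becomes exp(0) = 1 and exp() cannot underflow every term to zero. This version assumes
diagonal covariances (so log(det(Sigma_k)) is a sum of log variances and the Mahalanobis
term needs no matrix inverse) and builds S_ik from
log_weight_div_det[k] = -2*log(weights_k) + log(det(Sigma_k)). The helper name
responsibilities and its std::vector arguments are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: alpha_ik for one sample x under a mixture with diagonal
// covariances (var[k][j] > 0 is the j-th variance of component k).
std::vector<double> responsibilities(const std::vector<double>& x,
                                     const std::vector<double>& weights,
                                     const std::vector<std::vector<double>>& means,
                                     const std::vector<std::vector<double>>& var)
{
    const std::size_t nclusters = weights.size();
    assert(nclusters > 0 && means.size() == nclusters && var.size() == nclusters);
    std::vector<double> S(nclusters);
    for (std::size_t k = 0; k < nclusters; k++) {
        double log_det = 0.0, maha = 0.0;
        for (std::size_t j = 0; j < x.size(); j++) {
            double d = x[j] - means[k][j];
            log_det += std::log(var[k][j]);  // log det(Sigma_k), diagonal case
            maha += d * d / var[k][j];       // (x - mu)' Sigma^(-1) (x - mu)
        }
        double log_weight_div_det = -2.0 * std::log(weights[k]) + log_det;
        S[k] = -0.5 * (log_weight_div_det + maha);
    }
    const double M = *std::max_element(S.begin(), S.end());  // M_i = S_iq
    double denom = 0.0;
    for (double s : S)
        denom += std::exp(s - M);  // 1 + sum_{j != q} exp(S_ij - M_i)
    std::vector<double> alpha(nclusters);
    for (std::size_t k = 0; k < nclusters; k++)
        alpha[k] = std::exp(S[k] - M) / denom;
    return alpha;
}
```

Usage: for x = 1000 and unit-variance components at 0 and 1, both S values are about
-5e5; exponentiating them directly underflows to 0/0, whereas the sketch returns
probabilities that sum to 1 with the second component dominant.
*/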
    int is_general = 0, is_diagonal = 0, is_spherical = 0;
    double prev_log_likelihood = -DBL_MAX / 1000., det, d;
    CvMat whdr, iwhdr, diag, *w, *iw;
    double* w_data;
    double* sp_data;

    if( nclusters == 1 )
    {
        double log_weight; // despite its name, holds sqrt(det(Sigma)), not a logarithm
        CV_CALL( cvSet( probs, cvScalar(1.)) );

        if( params.cov_mat_type == COV_MAT_SPHERICAL )
        {
            d = cvTrace(*covs).val[0]/dims;
            d = MAX( d, FLT_EPSILON );
            inv_eigen_values->data.db[0] = 1./d;
            log_weight = pow( d, dims*0.5 );
        }
        else
        {
            w_data = inv_eigen_values->data.db;

            if( params.cov_mat_type == COV_MAT_GENERIC )
                cvSVD( *covs, inv_eigen_values, *cov_rotate_mats, 0, CV_SVD_U_T );
            else
                cvTranspose( cvGetDiag(*covs, &diag), inv_eigen_values );

            cvMaxS( inv_eigen_values, FLT_EPSILON, inv_eigen_values );
            for( j = 0, det = 1.; j < dims; j++ )
                det *= w_data[j];
            log_weight = sqrt(det);
            cvDiv( 0, inv_eigen_values, inv_eigen_values );
        }

        log_weight_div_det->data.db[0] = -2*log(weights->data.db[0]/log_weight);
        log_likelihood = DBL_MAX/1000.;
        EXIT;
    }

    if( params.cov_mat_type == COV_MAT_GENERIC )
        is_general = 1;
    else if( params.cov_mat_type == COV_MAT_DIAGONAL )
        is_diagonal = 1;
    else if( params.cov_mat_type == COV_MAT_SPHERICAL )
        is_spherical = 1;
    /* If <cov_mat_type> == COV_MAT_GENERIC, the k-th row of cov_eigen_values holds
       the eigenvalues of covs[k]. If <cov_mat_type> == COV_MAT_DIAGONAL, it holds
       the diagonal elements (variances) of covs[k]. If <cov_mat_type> == COV_MAT_SPHERICAL,
       the 0-th element of cov_eigen_values[k] holds the mean of the variances
       over all dimensions. */

    CV_CALL( log_det = cvCreateMat( 1, nclusters, CV_64FC1 ));
    CV_CALL( log_weights = cvCreateMat( 1, nclusters, CV_64FC1 ));
    CV_CALL( covs_item = cvCreateMat( dims, dims, CV_64FC1 ));
    CV_CALL( centered_sample = cvCreateMat( 1, dims, CV_64FC1 ));
    CV_CALL( cov_eigen_values = cvCreateMat( inv_eigen_values->rows, inv_eigen_values->cols, CV_64FC1 ));
    CV_CALL( samples = cvCreateMat( nsamples, dims, CV_64FC1 ));
    CV_CALL( sum_probs = cvCreateMat( 1, nclusters, CV_64FC1 ));
    sp_data = sum_probs->data.db;

    // copy the training data into a double-precision matrix
    for( i = 0; i < nsamples; i++ )
    {
        const float* src = train_data.data.fl[i];
        double* dst = (double*)(samples->data.ptr + samples->step*i);

        for( j = 0; j < dims; j++ )
            dst[j] = src[j];
    }

    if( start_step != START_M_STEP )
    {
        for( k = 0; k < nclusters; k++ )
        {
            if( is_general || is_diagonal )
            {
                w = cvGetRow( cov_eigen_values, &whdr, k );
                if( is_general )
                    cvSVD( covs[k], w, cov_rotate_mats[k], 0, CV_SVD_U_T );
                else
                    cvTranspose( cvGetDiag( covs[k], &diag ), w );
                w_data = w->data.db;
                for( j = 0, det = 1.; j < dims; j++ )
                    det *= w_data[j];
                if( det < min_det_value )
                {
                    if( start_step == START_AUTO_STEP )
                        det = min_det_value;
                    else
                        EXIT;
                }
                log_det->data.db[k] = det;
            }
            else
            {
                d = cvTrace(covs[k]).val[0]/(double)dims;
                if( d < min_variation )
                {
                    if( start_step == START_AUTO_STEP )
                        d = min_variation;
                    else
                        EXIT;
                }
                cov_eigen_values->data.db[k] = d;
                log_det->data.db[k] = d;
            }
        }

        cvLog( log_det, log_det );
        if( is_spherical )
            cvScale( log_det, log_det, dims );
    }

    for( n = 0; n < params.term_crit.max_iter; n++ )
    {
        if( n > 0 || start_step != START_M_STEP )
        {
            // e-step: compute probs_ik from means_k, covs_k and weights_k.
            CV_CALL(cvLog( weights, log_weights ));
            double sum_max = 0; // sum over samples of max_j S_ij, added back into the log-likelihood

            // S_ik = -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] + log(weights_k)
            for( k = 0; k < nclusters; k++ )
            {
                CvMat* u = cov_rotate_mats[k];
                const double* mean = (double*)(means->data.ptr + means->step*k);
                w = cvGetRow( cov_eigen_values, &whdr, k );
                iw = cvGetRow( inv_eigen_values, &iwhdr, k );
                cvDiv( 0, w, iw );

                w_data = (double*)(inv_eigen_values->data.ptr + inv_eigen_values->step*k);

                for( i = 0; i < nsamples; i++ )
                {
                    double *csample = centered_sample->data.db, p = log_det->data.db[k];
                    const double* sample = (double*)(samples->data.ptr + samples->step*i);
                    double* pp = (double*)(probs->data.ptr + probs->step*i);
                    for( j = 0; j < dims; j++ )
                        csample[j] = sample[j] - mean[j];
                    if( is_general )
                        cvGEMM( centered_sample, u, 1, 0, 0, centered_sample, CV_GEMM_B_T );
                    for( j = 0; j < dims; j++ )
                        p += csample[j]*csample[j]*w_data[is_spherical ? 0 : j];
                    pp[k] = -0.5*p + log_weights->data.db[k];

                    // S_ik <- S_ik - max_j S_ij (start from pp[0] so that all-negative
                    // scores are shifted too; the maximum is added back to the likelihood)
                    if( k == nclusters - 1 )
                    {
                        double max_val = pp[0];
                        for( j = 1; j < nclusters; j++ )
                            max_val = MAX( max_val, pp[j] );
                        for( j = 0; j < nclusters; j++ )
                            pp[j] -= max_val;
                        sum_max += max_val;
                    }
                }
            }

            CV_CALL(cvExp( probs, probs )); // exp( S_ik )
            cvZero( sum_probs );

            // alpha_ik = exp( S_ik ) / sum_j exp( S_ij ),
            // log_likelihood = sum_i log( sum_j exp(S_ij) ), evaluated as
            // sum_i [ max_j S_ij + log( sum_j exp(S_ij - max_j S_ij) ) ]
            for( i = 0, _log_likelihood = likelihood_bias + sum_max; i < nsamples; i++ )
            {
                double* pp = (double*)(probs->data.ptr + probs->step*i), sum = 0;
                for( j = 0; j < nclusters; j++ )
                    sum += pp[j];
                sum = 1./MAX( sum, DBL_EPSILON );
                for( j = 0; j < nclusters; j++ )
                {
                    double p = pp[j] *= sum;
                    sp_data[j] += p;
                }
                _log_likelihood -= log( sum );
            }

            // check termination criteria
            if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon )
                break;
            prev_log_likelihood = _log_likelihood;
        }

        // m-step: update means_k, covs_k and weights_k from probs_ik
        cvGEMM( probs, samples, 1, 0, 0, means, CV_GEMM_A_T );

        for( k = 0; k < nclusters; k++ )
        {
            double sum = sp_data[k], inv_sum = 1./sum;
            CvMat* cov = covs[k], _mean, _sample;

            w = cvGetRow( cov_eigen_values, &whdr, k );
            w_data = w->data.db;
            cvGetRow( means, &_mean, k );
            cvGetRow( samples, &_sample, k ); // only initializes the row header; data is re-pointed per sample below

            // update weights_k
            weights->data.db[k] = sum;

            // update means_k
            cvScale( &_mean, &_mean, inv_sum );

            // compute covs_k
            cvZero( cov );
            cvZero( w );

            for( i = 0; i < nsamples; i++ )
            {
                double p = probs->data.db[i*nclusters + k]*inv_sum;
                _sample.data.db = (double*)(samples->data.ptr + samples->step*i);

                if( is_general )
                {
                    cvMulTransposed( &_sample, covs_item, 1, &_mean );
                    cvScaleAdd( covs_item, cvRealScalar(p), cov, cov );
                }
                else
                    for( j = 0; j < dims; j++ )
                    {
                        double val = _sample.data.db[j] - _mean.data.db[j];
                        w_data[is_spherical ? 0 : j] += p*val*val;
                    }
            }

            if( is_spherical )
            {
                d = w_data[0]/(double)dims;
                d = MAX( d, min_variation );
                w->data.db[0] = d;
                log_det->data.db[k] = d;
            }
            else
            {
                if( is_general )
                    cvSVD( cov, w, cov_rotate_mats[k], 0, CV_SVD_U_T );
                cvMaxS( w, min_variation, w );
                for( j = 0, det = 1.; j < dims; j++ )
                    det *= w_data[j];
                log_det->data.db[k] = det;
            }
        }

        cvConvertScale( weights, weights, 1./(double)nsamples, 0 );
        cvMaxS( weights, DBL_MIN, weights );

        cvLog( log_det, log_det );
        if( is_spherical )
            cvScale( log_det, log_det, dims );
    } // end of iteration process

    // log_weight_div_det[k] = -2*log(weights_k/det(Sigma_k)^0.5) = -2*log(weights_k) + log(det(Sigma_k))
    if( log_weight_div_det )
    {
        cvScale( log_weights, log_weight_div_det, -2 );
        cvAdd( log_weight_div_det, log_det, log_weight_div_det );
    }

    /* Now finalize all the covariance matrices:
    1) if <cov_mat_type> == COV_MAT_DIAGONAL, the rows <w> of cov_eigen_values were used
       as the diagonals. Now w[k] must be copied back to the diagonal of covs[k];
    2) if <cov_mat_type> == COV_MAT_SPHERICAL, the 0-th element of w[k] was used
       as the average variance of each cluster. Its value must be copied
       to all of the diagonal elements of covs[k].
    */
    if( is_spherical )
    {
        for( k = 0; k < nclusters; k++ )
            cvSetIdentity( covs[k], cvScalar(cov_eigen_values->data.db[k]));
    }
    else if( is_diagonal )
    {
        for( k = 0; k < nclusters; k++ )
            cvTranspose( cvGetRow( cov_eigen_values, &whdr, k ),
                         cvGetDiag( covs[k], &diag ));
    }
    cvDiv( 0, cov_eigen_values, inv_eigen_values );

    log_likelihood = _log_likelihood;

    __END__;

    cvReleaseMat( &log_det );
    cvReleaseMat( &log_weights );
    cvReleaseMat( &covs_item );
    cvReleaseMat( &centered_sample );
    cvReleaseMat( &cov_eigen_values );
    cvReleaseMat( &samples );
    cvReleaseMat( &sum_probs );

    return log_likelihood;
}


int CvEM::get_nclusters() const
{
    return params.nclusters;
}

const CvMat* CvEM::get_means() const
{
    return means;
}

const CvMat** CvEM::get_covs() const
{
    return (const CvMat**)covs;
}

const CvMat* CvEM::get_weights() const
{
    return weights;
}

const CvMat* CvEM::get_probs() const
{
    return probs;
}

/* End of file. */
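The E-step in `CvEM::run_em` scores each sample against every cluster (S_ik), subtracts a per-sample maximum, exponentiates, and normalizes to get the responsibilities alpha_ik. Below is a minimal standalone sketch of that sequence for a single sample and diagonal covariances, assuming plain `std::vector` storage instead of `CvMat`. `estep_sample` is a hypothetical name, not an OpenCV function. The sketch shifts by the true maximum score, so even when every S_ik is very negative (a sample far from all means) the largest term is exp(0) = 1 and the sum cannot underflow to zero. It returns the per-sample log-sum-exp without the constant -D/2 * log(2*pi), which `run_em` adds once through `likelihood_bias`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not an OpenCV function): E-step for ONE sample of a
// Gaussian mixture with diagonal covariances. Fills `resp` with the
// responsibilities alpha_k and returns log(sum_k exp(S_k)), excluding the
// constant -D/2 * log(2*pi).
double estep_sample(const std::vector<double>& x,
                    const std::vector<std::vector<double>>& means,  // K x D
                    const std::vector<std::vector<double>>& vars,   // K x D diagonal variances
                    const std::vector<double>& weights,             // K mixture weights
                    std::vector<double>& resp)
{
    const std::size_t K = means.size(), D = x.size();
    assert(K > 0 && vars.size() == K && weights.size() == K);

    // S_k = -0.5 * [log det(Sigma_k) + (x - mu_k)' Sigma_k^-1 (x - mu_k)] + log(w_k)
    std::vector<double> s(K);
    for (std::size_t k = 0; k < K; ++k) {
        double p = 0.0;
        for (std::size_t j = 0; j < D; ++j) {
            const double c = x[j] - means[k][j];
            p += std::log(vars[k][j]) + c * c / vars[k][j];
        }
        s[k] = -0.5 * p + std::log(weights[k]);
    }

    // Shift by the largest score so the biggest exponent is exp(0) = 1.
    const double m = *std::max_element(s.begin(), s.end());
    double sum = 0.0;
    resp.assign(K, 0.0);
    for (std::size_t k = 0; k < K; ++k) {
        resp[k] = std::exp(s[k] - m);
        sum += resp[k];
    }
    for (double& r : resp)
        r /= sum;

    // log sum_k exp(S_k) = m + log sum_k exp(S_k - m)
    return m + std::log(sum);
}
```

`run_em` caches log det(Sigma_k) and the inverse variances per cluster rather than recomputing logarithms for every sample; the sketch recomputes them only to stay short.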
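The M-step in `CvEM::run_em` turns the responsibilities back into parameters. Each cluster's soft count n_k = sum_i alpha_ik normalizes the weighted mean and the weighted squared deviations, and n_k / N becomes the new mixture weight. Variances are floored with `cvMaxS( w, min_variation, w )`. Below is a minimal sketch of the diagonal-covariance case in plain C++. `DiagGmm` and `mstep_diag` are hypothetical names, not OpenCV types, and the sketch assumes, like `run_em`, that every cluster has a nonzero soft count. For `COV_MAT_SPHERICAL`, `run_em` instead accumulates the squared deviations of all dimensions into one value and divides it by `dims`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container (not an OpenCV type) for a diagonal-covariance mixture.
struct DiagGmm {
    std::vector<double> weights;             // K mixture weights
    std::vector<std::vector<double>> means;  // K x D
    std::vector<std::vector<double>> vars;   // K x D diagonal variances
};

// Hypothetical M-step: re-estimate weights, means and diagonal variances from
// resp[i][k] (each row sums to 1). Every variance is floored at min_var.
DiagGmm mstep_diag(const std::vector<std::vector<double>>& X,     // N x D samples
                   const std::vector<std::vector<double>>& resp,  // N x K responsibilities
                   double min_var)
{
    assert(!X.empty() && resp.size() == X.size());
    const std::size_t N = X.size(), D = X[0].size(), K = resp[0].size();

    DiagGmm g;
    g.weights.assign(K, 0.0);
    g.means.assign(K, std::vector<double>(D, 0.0));
    g.vars.assign(K, std::vector<double>(D, 0.0));

    for (std::size_t k = 0; k < K; ++k) {
        double nk = 0.0;  // soft count n_k = sum_i resp_ik
        for (std::size_t i = 0; i < N; ++i)
            nk += resp[i][k];
        assert(nk > 0.0);

        // mu_k = (1/n_k) * sum_i resp_ik * x_i
        for (std::size_t i = 0; i < N; ++i)
            for (std::size_t j = 0; j < D; ++j)
                g.means[k][j] += resp[i][k] * X[i][j] / nk;

        // sigma2_kj = (1/n_k) * sum_i resp_ik * (x_ij - mu_kj)^2
        for (std::size_t i = 0; i < N; ++i)
            for (std::size_t j = 0; j < D; ++j) {
                const double c = X[i][j] - g.means[k][j];
                g.vars[k][j] += resp[i][k] * c * c / nk;
            }
        for (double& v : g.vars[k])
            v = std::max(v, min_var);  // variance floor

        // w_k = n_k / N
        g.weights[k] = nk / static_cast<double>(N);
    }
    return g;
}
```

The sketch computes the means before the variances, as `run_em` does, because the deviations must be taken around the updated mean.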