16acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*M///////////////////////////////////////////////////////////////////////////////////////
26acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
36acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
46acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
56acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  By downloading, copying, installing or using the software you agree to this license.
66acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  If you do not agree to this license, do not download, install,
76acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  copy or use the software.
86acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
96acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//                        Intel License Agreement
116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//                For Open Source Computer Vision Library
126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Copyright( C) 2000, Intel Corporation, all rights reserved.
146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Third party copyrights are property of their respective owners.
156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Redistribution and use in source and binary forms, with or without modification,
176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are permitted provided that the following conditions are met:
186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's of source code must retain the above copyright notice,
206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer.
216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's in binary form must reproduce the above copyright notice,
236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer in the documentation
246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     and/or other materials provided with the distribution.
256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * The name of Intel Corporation may not be used to endorse or promote products
276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     derived from this software without specific prior written permission.
286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// This software is provided by the copyright holders and contributors "as is" and
306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// any express or implied warranties, including, but not limited to, the implied
316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// warranties of merchantability and fitness for a particular purpose are disclaimed.
326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct,
336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages
346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//(including, but not limited to, procurement of substitute goods or services;
356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused
366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability,
376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort(including negligence or otherwise) arising in any way out of
386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even ifadvised of the possibility of such damage.
396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/
416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h"
436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
/*
   CvEM:
 * params.nclusters   - number of clusters (mixture components) to split the samples into.
 * means              - nclusters x dims matrix of the Gaussian means computed by the EM algorithm.
 * log_weight_div_det - auxiliary 1 x nclusters vector whose k-th component is equal to
                        (-2)*ln(weights_k/det(Sigma_k)^0.5),
                        where <weights_k> is the weight and
                        <Sigma_k> is the covariance matrix of the k-th cluster.
 * inv_eigen_values   - nclusters x dims matrix (nclusters x 1 for COV_MAT_SPHERICAL);
                        row k, <inv_eigen_values>[k], contains the inverted eigenvalues
                        of the covariance matrix of the k-th cluster.
                        When <cov_mat_type> == COV_MAT_DIAGONAL,
                        inv_eigen_values[k] = Sigma_k^(-1).
 * cov_rotate_mats    - array of nclusters dims x dims matrices. train() allocates them for
                        every covariance type, but they are only used when
                        cov_mat_type == COV_MAT_GENERIC. <cov_rotate_mats>[k] is the
                        orthogonal matrix obtained by the SVD decomposition of Sigma_k.
   Both <inv_eigen_values> and <cov_rotate_mats> are used to represent the
   covariance matrices and to simplify the EM calculations.
   For a fixed k denote
   u = cov_rotate_mats[k],
   v = inv_eigen_values[k],
   w = v^(-1);
   if <cov_mat_type> == COV_MAT_GENERIC, then Sigma_k = u w u',
   else                                       Sigma_k = w.
   The symbol ' denotes transposition.
 */
706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvEM::CvEM()
736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    covs = cov_rotate_mats = 0;
766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvEM::CvEM( const CvMat* samples, const CvMat* sample_idx,
796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvEMParams params, CvMat* labels )
806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    covs = cov_rotate_mats = 0;
836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // just invoke the train() method
856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    train(samples, sample_idx, params, labels);
866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvEM::~CvEM()
896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvEM::clear()
956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i;
976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &means );
996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &weights );
1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &probs );
1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &inv_eigen_values );
1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &log_weight_div_det );
1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( covs || cov_rotate_mats )
1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < params.nclusters; i++ )
1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( covs )
1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvReleaseMat( &covs[i] );
1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( cov_rotate_mats )
1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvReleaseMat( &cov_rotate_mats[i] );
1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &covs );
1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &cov_rotate_mats );
1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvEM::set_params( const CvEMParams& _params, const CvVectors& train_data )
1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvEM::set_params" );
1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int k;
1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params = _params;
1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    params.term_crit = cvCheckTermCriteria( params.term_crit, 1e-6, 10000 );
1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.cov_mat_type != COV_MAT_SPHERICAL &&
1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.cov_mat_type != COV_MAT_DIAGONAL &&
1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.cov_mat_type != COV_MAT_GENERIC )
1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Unknown covariation matrix type" );
1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    switch( params.start_step )
1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case START_M_STEP:
1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !params.probs )
1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsNullPtr, "Probabilities must be specified when EM algorithm starts with M-step" );
1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case START_E_STEP:
1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !params.means )
1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsNullPtr, "Mean's must be specified when EM algorithm starts with E-step" );
1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    case START_AUTO_STEP:
1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        break;
1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    default:
1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Unknown start_step" );
1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.nclusters < 1 )
1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsOutOfRange, "The number of clusters (mixtures) should be > 0" );
1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.probs )
1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat* p = params.weights;
1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(p) ||
1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(p->type) != CV_32FC1  &&
1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(p->type) != CV_64FC1 ||
1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            p->rows != train_data.count ||
1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            p->cols != params.nclusters )
1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "The array of probabilities must be a valid "
1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "floating-point matrix (CvMat) of 'nsamples' x 'nclusters' size" );
1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.means )
1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat* m = params.means;
1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(m) ||
1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(m->type) != CV_32FC1  &&
1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(m->type) != CV_64FC1 ||
1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            m->rows != params.nclusters ||
1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            m->cols != train_data.dims )
1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "The array of mean's must be a valid "
1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "floating-point matrix (CvMat) of 'nsamples' x 'dims' size" );
1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.weights )
1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const CvMat* w = params.weights;
1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(w) ||
1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(w->type) != CV_32FC1  &&
1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_MAT_TYPE(w->type) != CV_64FC1 ||
1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            w->rows != 1 && w->cols != 1 ||
1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            w->rows + w->cols - 1 != params.nclusters )
1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg, "The array of weights must be a valid "
1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "1d floating-point vector (CvMat) of 'nclusters' elements" );
1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.covs )
1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < params.nclusters; k++ )
1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const CvMat* cov = params.covs[k];
1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( !CV_IS_MAT(cov) ||
1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_MAT_TYPE(cov->type) != CV_32FC1  &&
1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_MAT_TYPE(cov->type) != CV_64FC1 ||
1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cov->rows != cov->cols || cov->cols != train_data.dims )
1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_ERROR( CV_StsBadArg,
1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                "Each of covariation matrices must be a valid square "
2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                "floating-point matrix (CvMat) of 'dims' x 'dims'" );
2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat
2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvEM::predict( const CvMat* _sample, CvMat* _probs ) const
2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* sample_data   = 0;
2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    void* buffer = 0;
2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int allocated_buffer = 0;
2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int cls = 0;
2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvEM::predict" );
2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, k, dims;
2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int nclusters;
2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int cov_mat_type = params.cov_mat_type;
2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double opt = FLT_MAX;
2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    size_t size;
2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat diff, expo;
2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dims = means->cols;
2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nclusters = params.nclusters;
2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvPreparePredictData( _sample, dims, 0, params.nclusters, _probs, &sample_data ));
2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// allocate memory and initializing headers for calculating
2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    size = sizeof(double) * (nclusters + dims);
2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( size <= CV_MAX_LOCAL_SIZE )
2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        buffer = cvStackAlloc( size );
2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( buffer = cvAlloc( size ));
2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        allocated_buffer = 1;
2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    expo = cvMat( 1, nclusters, CV_64FC1, buffer );
2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    diff = cvMat( 1, dims, CV_64FC1, (double*)buffer + nclusters );
2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// calculate the probabilities
2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k < nclusters; k++ )
2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const double* mean_k = (const double*)(means->data.ptr + means->step*k);
2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const double* w = (const double*)(inv_eigen_values->data.ptr + inv_eigen_values->step*k);
2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double cur = log_weight_div_det->data.db[k];
2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat* u = cov_rotate_mats[k];
2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // cov = u w u'  -->  cov^(-1) = u w^(-1) u'
2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( cov_mat_type == COV_MAT_SPHERICAL )
2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double w0 = w[0];
2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < dims; i++ )
2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = sample_data[i] - mean_k[i];
2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cur += val*val*w0;
2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < dims; i++ )
2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                diff.data.db[i] = sample_data[i] - mean_k[i];
2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( cov_mat_type == COV_MAT_GENERIC )
2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T );
2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < dims; i++ )
2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double val = diff.data.db[i];
2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cur += val*val*w[i];
2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        expo.data.db[k] = cur;
2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( cur < opt )
2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cls = k;
2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            opt = cur;
2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        /* probability = (2*pi)^(-dims/2)*exp( -0.5 * cur ) */
2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _probs )
2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cvConvertScale( &expo, &expo, -0.5 ));
2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cvExp( &expo, &expo ));
2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( _probs->cols == 1 )
2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL( cvReshape( &expo, &expo, 0, nclusters ));
2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cvConvertScale( &expo, _probs, 1./cvSum( &expo ).val[0] ));
2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( sample_data != _sample->data.fl )
2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &sample_data );
2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( allocated_buffer )
2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &buffer );
2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (float)cls;
2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx,
3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                  CvEMParams _params, CvMat* labels )
3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool result = false;
3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvVectors train_data;
3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* sample_idx = 0;
3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    train_data.data.fl = 0;
3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    train_data.count = 0;
3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME("cvEM");
3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, nsamples, nclusters, dims;
3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvPrepareTrainData( "cvEM",
3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _samples, CV_ROW_SAMPLE, 0, CV_VAR_CATEGORICAL,
3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        0, _sample_idx, false, (const float***)&train_data.data.fl,
3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        &train_data.count, &train_data.dims, &train_data.dims,
3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        0, 0, 0, &sample_idx ));
3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( set_params( _params, train_data ));
3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nsamples = train_data.count;
3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nclusters = params.nclusters;
3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dims = train_data.dims;
3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( labels && (!CV_IS_MAT(labels) || CV_MAT_TYPE(labels->type) != CV_32SC1 ||
3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        labels->cols != 1 && labels->rows != 1 || labels->cols + labels->rows - 1 != nsamples ))
3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg,
3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "labels array (when passed) must be a valid 1d integer vector of <sample_count> elements" );
3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( nsamples <= nclusters )
3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsOutOfRange,
3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "The number of samples should be greater than the number of clusters" );
3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( log_weight_div_det = cvCreateMat( 1, nclusters, CV_64FC1 ));
3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( probs  = cvCreateMat( nsamples, nclusters, CV_64FC1 ));
3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( means = cvCreateMat( nclusters, dims, CV_64FC1 ));
3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( weights = cvCreateMat( 1, nclusters, CV_64FC1 ));
3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( inv_eigen_values = cvCreateMat( nclusters,
3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        params.cov_mat_type == COV_MAT_SPHERICAL ? 1 : dims, CV_64FC1 ));
3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( covs = (CvMat**)cvAlloc( nclusters * sizeof(*covs) ));
3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cov_rotate_mats = (CvMat**)cvAlloc( nclusters * sizeof(cov_rotate_mats[0]) ));
3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < nclusters; i++ )
3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( covs[i] = cvCreateMat( dims, dims, CV_64FC1 ));
3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cov_rotate_mats[i]  = cvCreateMat( dims, dims, CV_64FC1 ));
3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero( cov_rotate_mats[i] );
3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    init_em( train_data );
3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    log_likelihood = run_em( train_data );
3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( log_likelihood <= -DBL_MAX/10000. )
3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        EXIT;
3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( labels )
3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( nclusters == 1 )
3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvZero( labels );
3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvMat sample = cvMat( 1, dims, CV_32F );
3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvMat prob = cvMat( 1, nclusters, CV_64F );
3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int lstep = CV_IS_MAT_CONT(labels->type) ? 1 : labels->step/sizeof(int);
3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < nsamples; i++ )
3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int idx = sample_idx ? sample_idx->data.i[i] : i;
3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sample.data.ptr = _samples->data.ptr + _samples->step*idx;
3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                prob.data.ptr = probs->data.ptr + probs->step*i;
3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                labels->data.i[i*lstep] = cvRound(predict(&sample, &prob));
3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    result = true;
3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( sample_idx != _sample_idx )
3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvReleaseMat( &sample_idx );
3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &train_data.data.ptr );
3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvEM::init_em( const CvVectors& train_data )
3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat *w = 0, *u = 0, *tcov = 0;
3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvEM::init_em" );
4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
4036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double maxval = 0;
4056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, force_symm_plus = 0;
4066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int nclusters = params.nclusters, nsamples = train_data.count, dims = train_data.dims;
4076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.start_step == START_AUTO_STEP || nclusters == 1 || nclusters == nsamples )
4096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        init_auto( train_data );
4106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( params.start_step == START_M_STEP )
4116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < nsamples; i++ )
4136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvMat prob;
4156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvGetRow( params.probs, &prob, i );
4166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvMaxS( &prob, 0., &prob );
4176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvMinMaxLoc( &prob, 0, &maxval );
4186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( maxval < FLT_EPSILON )
4196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvSet( &prob, cvScalar(1./nclusters) );
4206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
4216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvNormalize( &prob, &prob, 1., 0, CV_L1 );
4226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        EXIT; // do not preprocess covariation matrices,
4246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn              // as in this case they are initialized at the first iteration of EM
4256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
4276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ASSERT( params.start_step == START_E_STEP && params.means );
4296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( params.weights && params.covs )
4306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvConvert( params.means, means );
4326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvReshape( weights, weights, 1, params.weights->rows );
4336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvConvert( params.weights, weights );
4346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvReshape( weights, weights, 1, 1 );
4356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvMaxS( weights, 0., weights );
4366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvMinMaxLoc( weights, 0, &maxval );
4376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( maxval < FLT_EPSILON )
4386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvSet( &weights, cvScalar(1./nclusters) );
4396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvNormalize( weights, weights, 1., 0, CV_L1 );
4406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < nclusters; i++ )
4416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CV_CALL( cvConvert( params.covs[i], covs[i] ));
4426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            force_symm_plus = 1;
4436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
4456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            init_auto( train_data );
4466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( tcov = cvCreateMat( dims, dims, CV_64FC1 ));
4496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( w = cvCreateMat( dims, dims, CV_64FC1 ));
4506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.cov_mat_type == COV_MAT_GENERIC )
4516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( u = cvCreateMat( dims, dims, CV_64FC1 ));
4526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < nclusters; i++ )
4546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( force_symm_plus )
4566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvTranspose( covs[i], tcov );
4586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvAddWeighted( covs[i], 0.5, tcov, 0.5, 0, tcov );
4596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
4616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvCopy( covs[i], tcov );
4626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSVD( tcov, w, u, 0, CV_SVD_MODIFY_A + CV_SVD_U_T + CV_SVD_V_T );
4636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( params.cov_mat_type == COV_MAT_SPHERICAL )
4646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvSetIdentity( covs[i], cvScalar(cvTrace(w).val[0]/dims) );
4656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else if( params.cov_mat_type == COV_MAT_DIAGONAL )
4666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvCopy( w, covs[i] );
4676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
4686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
4696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // generic case: covs[i] = (u')'*max(w,0)*u'
4706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvGEMM( u, w, 1, 0, 0, tcov, CV_GEMM_A_T );
4716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvGEMM( tcov, u, 1, 0, 0, covs[i], 0 );
4726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
4736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
4746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
4766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &w );
4786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &u );
4796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &tcov );
4806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
4816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvEM::init_auto( const CvVectors& train_data )
4846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
4856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* hdr = 0;
4866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const void** vec = 0;
4876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* class_ranges = 0;
4886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* labels = 0;
4896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvEM::init_auto" );
4916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
4936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int nclusters = params.nclusters, nsamples = train_data.count, dims = train_data.dims;
4956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j;
4966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( nclusters == nsamples )
4986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
4996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat src = cvMat( 1, dims, CV_32F );
5006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat dst = cvMat( 1, dims, CV_64F );
5016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < nsamples; i++ )
5026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            src.data.ptr = train_data.data.ptr[i];
5046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dst.data.ptr = means->data.ptr + means->step*i;
5056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvConvert( &src, &dst );
5066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvZero( covs[i] );
5076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvSetIdentity( cov_rotate_mats[i] );
5086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSetIdentity( probs );
5106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvSet( weights, cvScalar(1./nclusters) );
5116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
5136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
5146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int max_count = 0;
5156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( class_ranges = cvCreateMat( 1, nclusters+1, CV_32SC1 ));
5176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( nclusters > 1 )
5186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL( labels = cvCreateMat( 1, nsamples, CV_32SC1 ));
5206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            kmeans( train_data, nclusters, labels, cvTermCriteria( CV_TERMCRIT_ITER,
5216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    params.means ? 1 : 10, 0.5 ), params.means );
5226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL( cvSortSamplesByClasses( (const float**)train_data.data.fl,
5236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                            labels, class_ranges->data.i ));
5246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
5266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            class_ranges->data.i[0] = 0;
5286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            class_ranges->data.i[1] = nsamples;
5296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < nclusters; i++ )
5326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int left = class_ranges->data.i[i], right = class_ranges->data.i[i+1];
5346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            max_count = MAX( max_count, right - left );
5356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( hdr = (CvMat*)cvAlloc( max_count*sizeof(hdr[0]) ));
5376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( vec = (const void**)cvAlloc( max_count*sizeof(vec[0]) ));
5386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        hdr[0] = cvMat( 1, dims, CV_32F );
5396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < max_count; i++ )
5406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            vec[i] = hdr + i;
5426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            hdr[i] = hdr[0];
5436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < nclusters; i++ )
5466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
5476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int left = class_ranges->data.i[i], right = class_ranges->data.i[i+1];
5486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            int cluster_size = right - left;
5496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvMat avg;
5506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( cluster_size <= 0 )
5526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                continue;
5536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = left; j < right; j++ )
5556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                hdr[j - left].data.fl = train_data.data.fl[j];
5566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL( cvGetRow( means, &avg, i ));
5586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL( cvCalcCovarMatrix( vec, cluster_size, covs[i],
5596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                &avg, CV_COVAR_NORMAL | CV_COVAR_SCALE ));
5606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            weights->data.db[i] = (double)cluster_size/(double)nsamples;
5616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
5626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
5636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
5656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &class_ranges );
5676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &labels );
5686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &hdr );
5696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvFree( &vec );
5706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
5716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvEM::kmeans( const CvVectors& train_data, int nclusters, CvMat* labels,
5746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                   CvTermCriteria termcrit, const CvMat* centers0 )
5756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
5766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* centers = 0;
5776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* old_centers = 0;
5786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* counters = 0;
5796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvEM::kmeans" );
5816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
5836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvRNG rng = cvRNG(-1);
5856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, k, nsamples, dims;
5866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int iter = 0;
5876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double max_dist = DBL_MAX;
5886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    termcrit = cvCheckTermCriteria( termcrit, 1e-6, 100 );
5906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    termcrit.epsilon *= termcrit.epsilon;
5916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nsamples = train_data.count;
5926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    dims = train_data.dims;
5936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    nclusters = MIN( nclusters, nsamples );
5946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
5956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( centers = cvCreateMat( nclusters, dims, CV_64FC1 ));
5966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( old_centers = cvCreateMat( nclusters, dims, CV_64FC1 ));
5976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( counters = cvCreateMat( 1, nclusters, CV_32SC1 ));
5986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( old_centers );
5996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( centers0 )
6016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cvConvert( centers0, centers ));
6036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
6056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < nsamples; i++ )
6076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            labels->data.i[i] = i*nclusters/nsamples;
6086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvRandShuffle( labels, &rng );
6096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
6106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( ;; )
6126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
6136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvMat* temp;
6146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( iter > 0 || centers0 )
6166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < nsamples; i++ )
6186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
6196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const float* s = train_data.data.fl[i];
6206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int k_best = 0;
6216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double min_dist = DBL_MAX;
6226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( k = 0; k < nclusters; k++ )
6246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
6256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const double* c = (double*)(centers->data.ptr + k*centers->step);
6266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double dist = 0;
6276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j = 0; j <= dims - 4; j += 4 )
6296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
6306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        double t0 = c[j] - s[j];
6316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        double t1 = c[j+1] - s[j+1];
6326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dist += t0*t0 + t1*t1;
6336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        t0 = c[j+2] - s[j+2];
6346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        t1 = c[j+3] - s[j+3];
6356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dist += t0*t0 + t1*t1;
6366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
6376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( ; j < dims; j++ )
6396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
6406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        double t = c[j] - s[j];
6416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        dist += t*t;
6426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
6436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( min_dist > dist )
6456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
6466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        min_dist = dist;
6476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        k_best = k;
6486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
6496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
6506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                labels->data.i[i] = k_best;
6526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
6536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( ++iter > termcrit.max_iter )
6566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
6576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_SWAP( centers, old_centers, temp );
6596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero( centers );
6606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvZero( counters );
6616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // update centers
6636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( i = 0; i < nsamples; i++ )
6646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            const float* s = train_data.data.fl[i];
6666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            k = labels->data.i[i];
6676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double* c = (double*)(centers->data.ptr + k*centers->step);
6686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j <= dims - 4; j += 4 )
6706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
6716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double t0 = c[j] + s[j];
6726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double t1 = c[j+1] + s[j+1];
6736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                c[j] = t0;
6756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                c[j+1] = t1;
6766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                t0 = c[j+2] + s[j+2];
6786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                t1 = c[j+3] + s[j+3];
6796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                c[j+2] = t0;
6816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                c[j+3] = t1;
6826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
6836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ; j < dims; j++ )
6846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                c[j] += s[j];
6856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            counters->data.i[k]++;
6866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
6876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( iter > 1 )
6896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            max_dist = 0;
6906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
6916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < nclusters; k++ )
6926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
6936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double* c = (double*)(centers->data.ptr + k*centers->step);
6946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( counters->data.i[k] != 0 )
6956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
6966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double scale = 1./counters->data.i[k];
6976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < dims; j++ )
6986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    c[j] *= scale;
6996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
7006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
7016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
7026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const float* s;
7036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < 10; j++ )
7046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
7056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    i = cvRandInt( &rng ) % nsamples;
7066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( counters->data.i[labels->data.i[i]] > 1 )
7076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        break;
7086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
7096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                s = train_data.data.fl[i];
7106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < dims; j++ )
7116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    c[j] = s[j];
7126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
7136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( iter > 1 )
7156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
7166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double dist = 0;
7176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const double* c_o = (double*)(old_centers->data.ptr + k*old_centers->step);
7186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < dims; j++ )
7196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
7206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double t = c[j] - c_o[j];
7216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    dist += t*t;
7226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
7236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( max_dist < dist )
7246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    max_dist = dist;
7256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
7266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
7276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( max_dist < termcrit.epsilon )
7296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            break;
7306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
7316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvZero( counters );
7336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < nsamples; i++ )
7346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        counters->data.i[labels->data.i[i]]++;
7356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // ensure that we do not have empty clusters
7376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( k = 0; k < nclusters; k++ )
7386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( counters->data.i[k] == 0 )
7396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for(;;)
7406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
7416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                i = cvRandInt(&rng) % nsamples;
7426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                j = labels->data.i[i];
7436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( counters->data.i[j] > 1 )
7446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
7456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    labels->data.i[i] = k;
7466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    counters->data.i[j]--;
7476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    counters->data.i[k]++;
7486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    break;
7496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
7506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
7516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
7536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &centers );
7556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &old_centers );
7566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &counters );
7576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
7586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************/
7616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* log_weight_div_det[k] = -2*log(weights_k) + log(det(Sigma_k)))
7626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   covs[k] = cov_rotate_mats[k] * cov_eigen_values[k] * (cov_rotate_mats[k])'
7646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   cov_rotate_mats[k] are orthogonal matrices of eigenvectors and
7656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   cov_eigen_values[k] are diagonal matrices (represented by 1D vectors) of eigen values.
7666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   The <alpha_ik> is the probability of the vector x_i to belong to the k-th cluster:
7686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   <alpha_ik> ~ weights_k * exp{ -0.5[ln(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] }
7696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   We calculate these probabilities here by the equivalent formulae:
7706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   Denote
7716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   S_ik = -0.5(log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)) + log(weights_k),
7726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   M_i = max_k S_ik = S_qi, so that the q-th class is the one where maximum reaches. Then
7736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn   alpha_ik = exp{ S_ik - M_i } / ( 1 + sum_j!=q exp{ S_ji - M_i })
7746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*/
7756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renndouble CvEM::run_em( const CvVectors& train_data )
7766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
7776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* centered_sample = 0;
7786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* covs_item = 0;
7796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* log_det = 0;
7806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* log_weights = 0;
7816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* cov_eigen_values = 0;
7826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* samples = 0;
7836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* sum_probs = 0;
7846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    log_likelihood = -DBL_MAX;
7856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvEM::run_em" );
7876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
7886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int nsamples = train_data.count, dims = train_data.dims, nclusters = params.nclusters;
7906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double min_variation = FLT_EPSILON;
7916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double min_det_value = MAX( DBL_MIN, pow( min_variation, dims ));
7926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double likelihood_bias = -CV_LOG2PI * (double)nsamples * (double)dims / 2., _log_likelihood = -DBL_MAX;
7936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int start_step = params.start_step;
7946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
7956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, k, n;
7966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int is_general = 0, is_diagonal = 0, is_spherical = 0;
7976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double prev_log_likelihood = -DBL_MAX / 1000., det, d;
7986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat whdr, iwhdr, diag, *w, *iw;
7996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* w_data;
8006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double* sp_data;
8016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( nclusters == 1 )
8036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double log_weight;
8056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( cvSet( probs, cvScalar(1.)) );
8066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( params.cov_mat_type == COV_MAT_SPHERICAL )
8086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
8096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            d = cvTrace(*covs).val[0]/dims;
8106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            d = MAX( d, FLT_EPSILON );
8116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            inv_eigen_values->data.db[0] = 1./d;
8126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            log_weight = pow( d, dims*0.5 );
8136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
8146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        else
8156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
8166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            w_data = inv_eigen_values->data.db;
8176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( params.cov_mat_type == COV_MAT_GENERIC )
8196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvSVD( *covs, inv_eigen_values, *cov_rotate_mats, 0, CV_SVD_U_T );
8206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
8216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvTranspose( cvGetDiag(*covs, &diag), inv_eigen_values );
8226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvMaxS( inv_eigen_values, FLT_EPSILON, inv_eigen_values );
8246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0, det = 1.; j < dims; j++ )
8256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                det *= w_data[j];
8266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            log_weight = sqrt(det);
8276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvDiv( 0, inv_eigen_values, inv_eigen_values );
8286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
8296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        log_weight_div_det->data.db[0] = -2*log(weights->data.db[0]/log_weight);
8316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        log_likelihood = DBL_MAX/1000.;
8326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        EXIT;
8336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( params.cov_mat_type == COV_MAT_GENERIC )
8366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        is_general  = 1;
8376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( params.cov_mat_type == COV_MAT_DIAGONAL )
8386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        is_diagonal = 1;
8396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( params.cov_mat_type == COV_MAT_SPHERICAL )
8406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        is_spherical  = 1;
8416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    /* In the case of <cov_mat_type> == COV_MAT_DIAGONAL, the k-th row of cov_eigen_values
8426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    contains the diagonal elements (variations). In the case of
8436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    <cov_mat_type> == COV_MAT_SPHERICAL - the 0-ths elements of the vectors cov_eigen_values[k]
8446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    are to be equal to the mean of the variations over all the dimensions. */
8456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( log_det = cvCreateMat( 1, nclusters, CV_64FC1 ));
8476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( log_weights = cvCreateMat( 1, nclusters, CV_64FC1 ));
8486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( covs_item = cvCreateMat( dims, dims, CV_64FC1 ));
8496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( centered_sample = cvCreateMat( 1, dims, CV_64FC1 ));
8506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cov_eigen_values = cvCreateMat( inv_eigen_values->rows, inv_eigen_values->cols, CV_64FC1 ));
8516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( samples = cvCreateMat( nsamples, dims, CV_64FC1 ));
8526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( sum_probs = cvCreateMat( 1, nclusters, CV_64FC1 ));
8536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    sp_data = sum_probs->data.db;
8546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // copy the training data into double-precision matrix
8566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < nsamples; i++ )
8576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const float* src = train_data.data.fl[i];
8596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        double* dst = (double*)(samples->data.ptr + samples->step*i);
8606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < dims; j++ )
8626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dst[j] = src[j];
8636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
8646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
8656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( start_step != START_M_STEP )
8666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
8676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < nclusters; k++ )
8686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
8696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( is_general || is_diagonal )
8706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
8716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                w = cvGetRow( cov_eigen_values, &whdr, k );
8726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( is_general )
8736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cvSVD( covs[k], w, cov_rotate_mats[k], 0, CV_SVD_U_T );
8746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                else
8756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cvTranspose( cvGetDiag( covs[k], &diag ), w );
8766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                w_data = w->data.db;
8776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0, det = 1.; j < dims; j++ )
8786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    det *= w_data[j];
8796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( det < min_det_value )
8806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
8816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( start_step == START_AUTO_STEP )
8826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        det = min_det_value;
8836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    else
8846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        EXIT;
8856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
8866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                log_det->data.db[k] = det;
8876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
8886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
8896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
8906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                d = cvTrace(covs[k]).val[0]/(double)dims;
8916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( d < min_variation )
8926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
8936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( start_step == START_AUTO_STEP )
8946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        d = min_variation;
8956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    else
8966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        EXIT;
8976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
8986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cov_eigen_values->data.db[k] = d;
8996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                log_det->data.db[k] = d;
9006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
9016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvLog( log_det, log_det );
9046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( is_spherical )
9056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvScale( log_det, log_det, dims );
9066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
9076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( n = 0; n < params.term_crit.max_iter; n++ )
9096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
9106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( n > 0 || start_step != START_M_STEP )
9116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // e-step: compute probs_ik from means_k, covs_k and weights_k.
9136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL(cvLog( weights, log_weights ));
9146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // S_ik = -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] + log(weights_k)
9166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( k = 0; k < nclusters; k++ )
9176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
9186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                CvMat* u = cov_rotate_mats[k];
9196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const double* mean = (double*)(means->data.ptr + means->step*k);
9206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                w = cvGetRow( cov_eigen_values, &whdr, k );
9216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                iw = cvGetRow( inv_eigen_values, &iwhdr, k );
9226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvDiv( 0, w, iw );
9236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                w_data = (double*)(inv_eigen_values->data.ptr + inv_eigen_values->step*k);
9256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( i = 0; i < nsamples; i++ )
9276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
9286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double *csample = centered_sample->data.db, p = log_det->data.db[k];
9296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    const double* sample = (double*)(samples->data.ptr + samples->step*i);
9306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double* pp = (double*)(probs->data.ptr + probs->step*i);
9316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j = 0; j < dims; j++ )
9326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        csample[j] = sample[j] - mean[j];
9336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( is_general )
9346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        cvGEMM( centered_sample, u, 1, 0, 0, centered_sample, CV_GEMM_B_T );
9356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j = 0; j < dims; j++ )
9366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        p += csample[j]*csample[j]*w_data[is_spherical ? 0 : j];
9376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    pp[k] = -0.5*p + log_weights->data.db[k];
9386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    // S_ik <- S_ik - max_j S_ij
9406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( k == nclusters - 1 )
9416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
9426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        double max_val = 0;
9436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        for( j = 0; j < nclusters; j++ )
9446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            max_val = MAX( max_val, pp[j] );
9456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        for( j = 0; j < nclusters; j++ )
9466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            pp[j] -= max_val;
9476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
9486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
9496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
9506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_CALL(cvExp( probs, probs )); // exp( S_ik )
9526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvZero( sum_probs );
9536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // alpha_ik = exp( S_ik ) / sum_j exp( S_ij ),
9556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // log_likelihood = sum_i log (sum_j exp(S_ij))
9566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0, _log_likelihood = likelihood_bias; i < nsamples; i++ )
9576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
9586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double* pp = (double*)(probs->data.ptr + probs->step*i), sum = 0;
9596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < nclusters; j++ )
9606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    sum += pp[j];
9616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                sum = 1./MAX( sum, DBL_EPSILON );
9626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < nclusters; j++ )
9636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
9646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double p = pp[j] *= sum;
9656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    sp_data[j] += p;
9666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
9676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                _log_likelihood -= log( sum );
9686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
9696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // check termination criteria
9716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon )
9726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                break;
9736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            prev_log_likelihood = _log_likelihood;
9746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
9756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        // m-step: update means_k, covs_k and weights_k from probs_ik
9776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvGEMM( probs, samples, 1, 0, 0, means, CV_GEMM_A_T );
9786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < nclusters; k++ )
9806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
9816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            double sum = sp_data[k], inv_sum = 1./sum;
9826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CvMat* cov = covs[k], _mean, _sample;
9836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            w = cvGetRow( cov_eigen_values, &whdr, k );
9856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            w_data = w->data.db;
9866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvGetRow( means, &_mean, k );
9876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvGetRow( samples, &_sample, k );
9886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // update weights_k
9906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            weights->data.db[k] = sum;
9916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // update means_k
9936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvScale( &_mean, &_mean, inv_sum );
9946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            // compute covs_k
9966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvZero( cov );
9976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvZero( w );
9986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
9996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < nsamples; i++ )
10006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
10016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double p = probs->data.db[i*nclusters + k]*inv_sum;
10026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                _sample.data.db = (double*)(samples->data.ptr + samples->step*i);
10036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( is_general )
10056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
10066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cvMulTransposed( &_sample, covs_item, 1, &_mean );
10076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cvScaleAdd( covs_item, cvRealScalar(p), cov, cov );
10086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
10096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                else
10106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j = 0; j < dims; j++ )
10116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
10126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        double val = _sample.data.db[j] - _mean.data.db[j];
10136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        w_data[is_spherical ? 0 : j] += p*val*val;
10146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
10156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
10166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( is_spherical )
10186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
10196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                d = w_data[0]/(double)dims;
10206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                d = MAX( d, min_variation );
10216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                w->data.db[0] = d;
10226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                log_det->data.db[k] = d;
10236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
10246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
10256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
10266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( is_general )
10276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    cvSVD( cov, w, cov_rotate_mats[k], 0, CV_SVD_U_T );
10286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                cvMaxS( w, min_variation, w );
10296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0, det = 1.; j < dims; j++ )
10306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    det *= w_data[j];
10316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                log_det->data.db[k] = det;
10326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
10336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
10346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvConvertScale( weights, weights, 1./(double)nsamples, 0 );
10366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvMaxS( weights, DBL_MIN, weights );
10376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvLog( log_det, log_det );
10396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( is_spherical )
10406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvScale( log_det, log_det, dims );
10416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    } // end of iteration process
10426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    //log_weight_div_det[k] = -2*log(weights_k/det(Sigma_k))^0.5) = -2*log(weights_k) + log(det(Sigma_k)))
10446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( log_weight_div_det )
10456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvScale( log_weights, log_weight_div_det, -2 );
10476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvAdd( log_weight_div_det, log_det, log_weight_div_det );
10486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
10496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    /* Now finalize all the covariation matrices:
10516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    1) if <cov_mat_type> == COV_MAT_DIAGONAL we used array of <w> as diagonals.
10526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn       Now w[k] should be copied back to the diagonals of covs[k];
10536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    2) if <cov_mat_type> == COV_MAT_SPHERICAL we used the 0-th element of w[k]
10546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn       as an average variation in each cluster. The value of the 0-th element of w[k]
10556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn       should be copied to the all of the diagonal elements of covs[k]. */
10566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( is_spherical )
10576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < nclusters; k++ )
10596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvSetIdentity( covs[k], cvScalar(cov_eigen_values->data.db[k]));
10606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
10616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else if( is_diagonal )
10626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
10636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( k = 0; k < nclusters; k++ )
10646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            cvTranspose( cvGetRow( cov_eigen_values, &whdr, k ),
10656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                         cvGetDiag( covs[k], &diag ));
10666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
10676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvDiv( 0, cov_eigen_values, inv_eigen_values );
10686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    log_likelihood = _log_likelihood;
10706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
10726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &log_det );
10746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &log_weights );
10756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &covs_item );
10766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &centered_sample );
10776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &cov_eigen_values );
10786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &samples );
10796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    cvReleaseMat( &sum_probs );
10806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return log_likelihood;
10826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
10836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvEM::get_nclusters() const
10866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
10876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return params.nclusters;
10886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
10896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvMat* CvEM::get_means() const
10916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
10926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return means;
10936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
10946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
10956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvMat** CvEM::get_covs() const
10966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
10976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return (const CvMat**)covs;
10986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
10996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvMat* CvEM::get_weights() const
11016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
11026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return weights;
11036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
11046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennconst CvMat* CvEM::get_probs() const
11066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
11076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return probs;
11086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
11096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
11106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* End of file. */
1111