16acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/*M///////////////////////////////////////////////////////////////////////////////////////
26acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
36acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
46acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
56acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  By downloading, copying, installing or using the software you agree to this license.
66acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  If you do not agree to this license, do not download, install,
76acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//  copy or use the software.
86acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
96acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//                        Intel License Agreement
116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Copyright (C) 2000, Intel Corporation, all rights reserved.
136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Third party copyrights are property of their respective owners.
146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// Redistribution and use in source and binary forms, with or without modification,
166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// are permitted provided that the following conditions are met:
176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's of source code must retain the above copyright notice,
196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer.
206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * Redistribution's in binary form must reproduce the above copyright notice,
226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     this list of conditions and the following disclaimer in the documentation
236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     and/or other materials provided with the distribution.
246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//   * The name of Intel Corporation may not be used to endorse or promote products
266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//     derived from this software without specific prior written permission.
276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// This software is provided by the copyright holders and contributors "as is" and
296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// any express or implied warranties, including, but not limited to, the implied
306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// warranties of merchantability and fitness for a particular purpose are disclaimed.
316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// In no event shall the Intel Corporation or contributors be liable for any direct,
326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// indirect, incidental, special, exemplary, or consequential damages
336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// (including, but not limited to, procurement of substitute goods or services;
346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// loss of use, data, or profits; or business interruption) however caused
356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// and on any theory of liability, whether in contract, strict liability,
366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// or tort (including negligence or otherwise) arising in any way out of
376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// the use of this software, even if advised of the possibility of such damage.
386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//
396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn//M*/
406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn#include "_ml.h"
426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/****************************************************************************************\
446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn*                          K-Nearest Neighbors Classifier                                *
456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn\****************************************************************************************/
466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn// k Nearest Neighbors
486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvKNearest::CvKNearest()
496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    samples = 0;
516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvKNearest::~CvKNearest()
566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    clear();
586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius RennCvKNearest::CvKNearest( const CvMat* _train_data, const CvMat* _responses,
626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        const CvMat* _sample_idx, bool _is_regression, int _max_k )
636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    samples = 0;
656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    train( _train_data, _responses, _sample_idx, _is_regression, _max_k, false );
666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvKNearest::clear()
706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    while( samples )
726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CvVectors* next_samples = samples->next;
746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &samples->data.fl );
756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &samples );
766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        samples = next_samples;
776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    var_count = 0;
796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    total = 0;
806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    max_k = 0;
816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvKNearest::get_max_k() const { return max_k; }
856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvKNearest::get_var_count() const { return var_count; }
876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvKNearest::is_regression() const { return regression; }
896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennint CvKNearest::get_sample_count() const { return total; }
916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennbool CvKNearest::train( const CvMat* _train_data, const CvMat* _responses,
936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        const CvMat* _sample_idx, bool _is_regression,
946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        int _max_k, bool _update_base )
956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool ok = false;
976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* responses = 0;
986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvKNearest::train" );
1006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
1026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvVectors* _samples;
1046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float** _data;
1056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int _count, _dims, _dims_all, _rsize;
1066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !_update_base )
1086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        clear();
1096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // Prepare training data and related parameters.
1116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // Treat categorical responses as ordered - to prevent class label compression and
1126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    // to enable entering new classes in the updates
1136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( cvPrepareTrainData( "CvKNearest::train", _train_data, CV_ROW_SAMPLE,
1146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _responses, CV_VAR_ORDERED, 0, _sample_idx, true, (const float***)&_data,
1156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        &_count, &_dims, &_dims_all, &responses, 0, 0 ));
1166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _update_base && _dims != var_count )
1186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "The newly added data have different dimensionality" );
1196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !_update_base )
1216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( _max_k < 1 )
1236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsOutOfRange, "max_k must be a positive number" );
1246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        regression = _is_regression;
1266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        var_count = _dims;
1276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        max_k = _max_k;
1286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
1296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    _rsize = _count*sizeof(float);
1316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_CALL( _samples = (CvVectors*)cvAlloc( sizeof(*_samples) + _rsize ));
1326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    _samples->next = samples;
1336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    _samples->type = CV_32F;
1346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    _samples->data.fl = _data;
1356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    _samples->count = _count;
1366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    total += _count;
1376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    samples = _samples;
1396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    memcpy( _samples + 1, responses->data.fl, _rsize );
1406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    ok = true;
1426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
1446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return ok;
1466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
1476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennvoid CvKNearest::find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
1516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    float* neighbor_responses, const float** neighbors, float* dist ) const
1526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
1536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, count = end - start, k1 = 0, k2 = 0, d = var_count;
1546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvVectors* s = samples;
1556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( ; s != 0; s = s->next )
1576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
1586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        int n = s->count;
1596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        for( j = 0; j < n; j++ )
1606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
1616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( i = 0; i < count; i++ )
1626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
1636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double sum = 0;
1646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                Cv32suf si;
1656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const float* v = s->data.fl[j];
1666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const float* u = (float*)(_samples->data.ptr + _samples->step*(start + i));
1676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                Cv32suf* dd = (Cv32suf*)(dist + i*k);
1686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                float* nr;
1696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                const float** nn;
1706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int t, ii, ii1;
1716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( t = 0; t <= d - 4; t += 4 )
1736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
1746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double t0 = u[t] - v[t], t1 = u[t+1] - v[t+1];
1756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double t2 = u[t+2] - v[t+2], t3 = u[t+3] - v[t+3];
1766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    sum += t0*t0 + t1*t1 + t2*t2 + t3*t3;
1776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
1786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( ; t < d; t++ )
1806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
1816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    double t0 = u[t] - v[t];
1826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    sum += t0*t0;
1836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
1846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                si.f = (float)sum;
1866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( ii = k1-1; ii >= 0; ii-- )
1876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( si.i > dd[ii].i )
1886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        break;
1896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( ii >= k-1 )
1906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    continue;
1916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
1926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                nr = neighbor_responses + i*k;
1936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                nn = neighbors ? neighbors + (start + i)*k : 0;
1946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( ii1 = k2 - 1; ii1 > ii; ii1-- )
1956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
1966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    dd[ii1+1].i = dd[ii1].i;
1976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    nr[ii1+1] = nr[ii1];
1986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( nn ) nn[ii1+1] = nn[ii1];
1996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
2006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dd[ii+1].i = si.i;
2016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                nr[ii+1] = ((float*)(s + 1))[j];
2026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                if( nn )
2036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    nn[ii+1] = v;
2046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
2056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            k1 = MIN( k1+1, k );
2066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            k2 = MIN( k1, k-1 );
2076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
2106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat CvKNearest::write_results( int k, int k1, int start, int end,
2136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float* neighbor_responses, const float* dist,
2146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* _results, CvMat* _neighbor_responses,
2156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CvMat* _dist, Cv32suf* sort_buf ) const
2166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
2176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float result = 0.f;
2186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, j, j1, count = end - start;
2196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    double inv_scale = 1./k1;
2206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int rstep = _results && !CV_IS_MAT_CONT(_results->type) ? _results->step/sizeof(result) : 1;
2216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < count; i++ )
2236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
2246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        const Cv32suf* nr = (const Cv32suf*)(neighbor_responses + i*k);
2256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float* dst;
2266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float r;
2276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( _results || start+i == 0 )
2286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( regression )
2306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
2316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                double s = 0;
2326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < k1; j++ )
2336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    s += nr[j].f;
2346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                r = (float)(s*inv_scale);
2356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
2366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            else
2376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            {
2386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                int prev_start = 0, best_count = 0, cur_count;
2396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                Cv32suf best_val;
2406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 0; j < k1; j++ )
2426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    sort_buf[j].i = nr[j].i;
2436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = k1-1; j > 0; j-- )
2456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                {
2466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    bool swap_fl = false;
2476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    for( j1 = 0; j1 < j; j1++ )
2486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        if( sort_buf[j1].i > sort_buf[j1+1].i )
2496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        {
2506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            int t;
2516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            CV_SWAP( sort_buf[j1].i, sort_buf[j1+1].i, t );
2526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            swap_fl = true;
2536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        }
2546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( !swap_fl )
2556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        break;
2566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                }
2576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                best_val.i = 0;
2596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                for( j = 1; j <= k1; j++ )
2606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    if( j == k1 || sort_buf[j].i != sort_buf[j-1].i )
2616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    {
2626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        cur_count = j - prev_start;
2636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        if( best_count < cur_count )
2646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        {
2656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            best_count = cur_count;
2666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                            best_val.i = sort_buf[j-1].i;
2676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        }
2686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                        prev_start = j;
2696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    }
2706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                r = best_val.f;
2716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            }
2726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( start+i == 0 )
2746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                result = r;
2756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            if( _results )
2776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                _results->data.fl[(start + i)*rstep] = r;
2786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( _neighbor_responses )
2816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dst = (float*)(_neighbor_responses->data.ptr +
2836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                (start + i)*_neighbor_responses->step);
2846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < k1; j++ )
2856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[j] = nr[j].f;
2866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ; j < k; j++ )
2876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[j] = 0.f;
2886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
2906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( _dist )
2916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        {
2926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            dst = (float*)(_dist->data.ptr + (start + i)*_dist->step);
2936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( j = 0; j < k1; j++ )
2946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[j] = dist[j + i*k];
2956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            for( ; j < k; j++ )
2966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                dst[j] = 0.f;
2976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        }
2986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
2996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
3016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3026acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3036acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3046acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3056acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Rennfloat CvKNearest::find_nearest( const CvMat* _samples, int k, CvMat* _results,
3066acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const float** _neighbors, CvMat* _neighbor_responses, CvMat* _dist ) const
3076acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn{
3086acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float result = 0.f;
3096acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    bool local_alloc = false;
3106acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    float* buf = 0;
3116acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    const int max_blk_count = 128, max_buf_sz = 1 << 12;
3126acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3136acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    CV_FUNCNAME( "CvKNearest::find_nearest" );
3146acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3156acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __BEGIN__;
3166acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3176acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    int i, count, count_scale, blk_count0, blk_count = 0, buf_sz, k1;
3186acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3196acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !samples )
3206acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsError, "The search tree must be constructed first using train method" );
3216acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3226acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !CV_IS_MAT(_samples) ||
3236acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_MAT_TYPE(_samples->type) != CV_32FC1 ||
3246acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _samples->cols != var_count )
3256acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg, "Input samples must be floating-point matrix (<num_samples>x<var_count>)" );
3266acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3276acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _results && (!CV_IS_MAT(_results) ||
3286acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _results->cols != 1 && _results->rows != 1 ||
3296acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        _results->cols + _results->rows - 1 != _samples->rows) )
3306acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsBadArg,
3316acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "The results must be 1d vector containing as much elements as the number of samples" );
3326acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3336acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _results && CV_MAT_TYPE(_results->type) != CV_32FC1 &&
3346acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        (CV_MAT_TYPE(_results->type) != CV_32SC1 || regression))
3356acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsUnsupportedFormat,
3366acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        "The results must be floating-point or integer (in case of classification) vector" );
3376acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3386acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( k < 1 || k > max_k )
3396acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_ERROR( CV_StsOutOfRange, "k must be within 1..max_k range" );
3406acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3416acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _neighbor_responses )
3426acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3436acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(_neighbor_responses) || CV_MAT_TYPE(_neighbor_responses->type) != CV_32FC1 ||
3446acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _neighbor_responses->rows != _samples->rows || _neighbor_responses->cols != k )
3456acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg,
3466acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "The neighbor responses (if present) must be floating-point matrix of <num_samples> x <k> size" );
3476acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3486acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3496acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( _dist )
3506acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3516acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( !CV_IS_MAT(_dist) || CV_MAT_TYPE(_dist->type) != CV_32FC1 ||
3526acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            _dist->rows != _samples->rows || _dist->cols != k )
3536acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            CV_ERROR( CV_StsBadArg,
3546acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            "The distances from the neighbors (if present) must be floating-point matrix of <num_samples> x <k> size" );
3556acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3566acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3576acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    count = _samples->rows;
3586acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    count_scale = k*2*sizeof(float);
3596acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    blk_count0 = MIN( count, max_blk_count );
3606acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_sz = MIN( blk_count0 * count_scale, max_buf_sz );
3616acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    blk_count0 = MAX( buf_sz/count_scale, 1 );
3626acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    blk_count0 += blk_count0 % 2;
3636acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    blk_count0 = MIN( blk_count0, count );
3646acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    buf_sz = blk_count0 * count_scale + k*sizeof(float);
3656acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    k1 = get_sample_count();
3666acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    k1 = MIN( k1, k );
3676acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3686acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( buf_sz <= CV_MAX_LOCAL_SIZE )
3696acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3706acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        buf = (float*)cvStackAlloc( buf_sz );
3716acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        local_alloc = true;
3726acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3736acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    else
3746acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        CV_CALL( buf = (float*)cvAlloc( buf_sz ));
3756acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3766acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    for( i = 0; i < count; i += blk_count )
3776acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    {
3786acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        blk_count = MIN( count - i, blk_count0 );
3796acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float* neighbor_responses = buf;
3806acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float* dist = buf + blk_count*k;
3816acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        Cv32suf* sort_buf = (Cv32suf*)(dist + blk_count*k);
3826acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3836acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        find_neighbors_direct( _samples, k, i, i + blk_count,
3846acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                    neighbor_responses, _neighbors, dist );
3856acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3866acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        float r = write_results( k, k1, i, i + blk_count, neighbor_responses, dist,
3876acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn                                 _results, _neighbor_responses, _dist, sort_buf );
3886acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        if( i == 0 )
3896acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn            result = r;
3906acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    }
3916acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3926acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    __END__;
3936acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3946acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    if( !local_alloc )
3956acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn        cvFree( &buf );
3966acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
3976acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn    return result;
3986acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn}
3996acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
4006acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn/* End of file */
4016acb9a7ea3d7564944e12cbc73a857b88c1301eeMarius Renn
402