1/*M/////////////////////////////////////////////////////////////////////////////////////// 2// 3// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 4// 5// By downloading, copying, installing or using the software you agree to this license. 6// If you do not agree to this license, do not download, install, 7// copy or use the software. 8// 9// 10// Intel License Agreement 11// 12// Copyright (C) 2000, Intel Corporation, all rights reserved. 13// Third party copyrights are property of their respective owners. 14// 15// Redistribution and use in source and binary forms, with or without modification, 16// are permitted provided that the following conditions are met: 17// 18// * Redistribution's of source code must retain the above copyright notice, 19// this list of conditions and the following disclaimer. 20// 21// * Redistribution's in binary form must reproduce the above copyright notice, 22// this list of conditions and the following disclaimer in the documentation 23// and/or other materials provided with the distribution. 24// 25// * The name of Intel Corporation may not be used to endorse or promote products 26// derived from this software without specific prior written permission. 27// 28// This software is provided by the copyright holders and contributors "as is" and 29// any express or implied warranties, including, but not limited to, the implied 30// warranties of merchantability and fitness for a particular purpose are disclaimed. 31// In no event shall the Intel Corporation or contributors be liable for any direct, 32// indirect, incidental, special, exemplary, or consequential damages 33// (including, but not limited to, procurement of substitute goods or services; 34// loss of use, data, or profits; or business interruption) however caused 35// and on any theory of liability, whether in contract, strict liability, 36// or tort (including negligence or otherwise) arising in any way out of 37// the use of this software, even if advised of the possibility of such damage. 38// 39//M*/ 40 41#include "_ml.h" 42 43/****************************************************************************************\ 44* K-Nearest Neighbors Classifier * 45\****************************************************************************************/ 46 47// k Nearest Neighbors 48CvKNearest::CvKNearest() 49{ 50 samples = 0; 51 clear(); 52} 53 54 55CvKNearest::~CvKNearest() 56{ 57 clear(); 58} 59 60 61CvKNearest::CvKNearest( const CvMat* _train_data, const CvMat* _responses, 62 const CvMat* _sample_idx, bool _is_regression, int _max_k ) 63{ 64 samples = 0; 65 train( _train_data, _responses, _sample_idx, _is_regression, _max_k, false ); 66} 67 68 69void CvKNearest::clear() 70{ 71 while( samples ) 72 { 73 CvVectors* next_samples = samples->next; 74 cvFree( &samples->data.fl ); 75 cvFree( &samples ); 76 samples = next_samples; 77 } 78 var_count = 0; 79 total = 0; 80 max_k = 0; 81} 82 83 84int CvKNearest::get_max_k() const { return max_k; } 85 86int CvKNearest::get_var_count() const { return var_count; } 87 88bool CvKNearest::is_regression() const { return regression; } 89 90int CvKNearest::get_sample_count() const { return total; } 91 92bool CvKNearest::train( const CvMat* _train_data, const CvMat* _responses, 93 const CvMat* _sample_idx, bool _is_regression, 94 int _max_k, bool _update_base ) 95{ 96 bool ok = false; 97 CvMat* responses = 0; 98 99 CV_FUNCNAME( "CvKNearest::train" ); 100 101 __BEGIN__; 102 103 CvVectors* _samples; 104 float** _data; 105 int _count, _dims, _dims_all, _rsize; 106 107 if( !_update_base ) 108 clear(); 109 110 // Prepare training data and related parameters. 111 // Treat categorical responses as ordered - to prevent class label compression and 112 // to enable entering new classes in the updates 113 CV_CALL( cvPrepareTrainData( "CvKNearest::train", _train_data, CV_ROW_SAMPLE, 114 _responses, CV_VAR_ORDERED, 0, _sample_idx, true, (const float***)&_data, 115 &_count, &_dims, &_dims_all, &responses, 0, 0 )); 116 117 if( _update_base && _dims != var_count ) 118 CV_ERROR( CV_StsBadArg, "The newly added data have different dimensionality" ); 119 120 if( !_update_base ) 121 { 122 if( _max_k < 1 ) 123 CV_ERROR( CV_StsOutOfRange, "max_k must be a positive number" ); 124 125 regression = _is_regression; 126 var_count = _dims; 127 max_k = _max_k; 128 } 129 130 _rsize = _count*sizeof(float); 131 CV_CALL( _samples = (CvVectors*)cvAlloc( sizeof(*_samples) + _rsize )); 132 _samples->next = samples; 133 _samples->type = CV_32F; 134 _samples->data.fl = _data; 135 _samples->count = _count; 136 total += _count; 137 138 samples = _samples; 139 memcpy( _samples + 1, responses->data.fl, _rsize ); 140 141 ok = true; 142 143 __END__; 144 145 return ok; 146} 147 148 149 150void CvKNearest::find_neighbors_direct( const CvMat* _samples, int k, int start, int end, 151 float* neighbor_responses, const float** neighbors, float* dist ) const 152{ 153 int i, j, count = end - start, k1 = 0, k2 = 0, d = var_count; 154 CvVectors* s = samples; 155 156 for( ; s != 0; s = s->next ) 157 { 158 int n = s->count; 159 for( j = 0; j < n; j++ ) 160 { 161 for( i = 0; i < count; i++ ) 162 { 163 double sum = 0; 164 Cv32suf si; 165 const float* v = s->data.fl[j]; 166 const float* u = (float*)(_samples->data.ptr + _samples->step*(start + i)); 167 Cv32suf* dd = (Cv32suf*)(dist + i*k); 168 float* nr; 169 const float** nn; 170 int t, ii, ii1; 171 172 for( t = 0; t <= d - 4; t += 4 ) 173 { 174 double t0 = u[t] - v[t], t1 = u[t+1] - v[t+1]; 175 double t2 = u[t+2] - v[t+2], t3 = u[t+3] - v[t+3]; 176 sum += t0*t0 + t1*t1 + t2*t2 + t3*t3; 177 } 178 179 for( ; t < d; t++ ) 180 { 181 double t0 = u[t] - v[t]; 182 sum += t0*t0; 183 } 184 185 si.f = (float)sum; 186 for( ii = k1-1; ii >= 0; ii-- ) 187 if( si.i > dd[ii].i ) 188 break; 189 if( ii >= k-1 ) 190 continue; 191 192 nr = neighbor_responses + i*k; 193 nn = neighbors ? neighbors + (start + i)*k : 0; 194 for( ii1 = k2 - 1; ii1 > ii; ii1-- ) 195 { 196 dd[ii1+1].i = dd[ii1].i; 197 nr[ii1+1] = nr[ii1]; 198 if( nn ) nn[ii1+1] = nn[ii1]; 199 } 200 dd[ii+1].i = si.i; 201 nr[ii+1] = ((float*)(s + 1))[j]; 202 if( nn ) 203 nn[ii+1] = v; 204 } 205 k1 = MIN( k1+1, k ); 206 k2 = MIN( k1, k-1 ); 207 } 208 } 209} 210 211 212float CvKNearest::write_results( int k, int k1, int start, int end, 213 const float* neighbor_responses, const float* dist, 214 CvMat* _results, CvMat* _neighbor_responses, 215 CvMat* _dist, Cv32suf* sort_buf ) const 216{ 217 float result = 0.f; 218 int i, j, j1, count = end - start; 219 double inv_scale = 1./k1; 220 int rstep = _results && !CV_IS_MAT_CONT(_results->type) ? _results->step/sizeof(result) : 1; 221 222 for( i = 0; i < count; i++ ) 223 { 224 const Cv32suf* nr = (const Cv32suf*)(neighbor_responses + i*k); 225 float* dst; 226 float r; 227 if( _results || start+i == 0 ) 228 { 229 if( regression ) 230 { 231 double s = 0; 232 for( j = 0; j < k1; j++ ) 233 s += nr[j].f; 234 r = (float)(s*inv_scale); 235 } 236 else 237 { 238 int prev_start = 0, best_count = 0, cur_count; 239 Cv32suf best_val; 240 241 for( j = 0; j < k1; j++ ) 242 sort_buf[j].i = nr[j].i; 243 244 for( j = k1-1; j > 0; j-- ) 245 { 246 bool swap_fl = false; 247 for( j1 = 0; j1 < j; j1++ ) 248 if( sort_buf[j1].i > sort_buf[j1+1].i ) 249 { 250 int t; 251 CV_SWAP( sort_buf[j1].i, sort_buf[j1+1].i, t ); 252 swap_fl = true; 253 } 254 if( !swap_fl ) 255 break; 256 } 257 258 best_val.i = 0; 259 for( j = 1; j <= k1; j++ ) 260 if( j == k1 || sort_buf[j].i != sort_buf[j-1].i ) 261 { 262 cur_count = j - prev_start; 263 if( best_count < cur_count ) 264 { 265 best_count = cur_count; 266 best_val.i = sort_buf[j-1].i; 267 } 268 prev_start = j; 269 } 270 r = best_val.f; 271 } 272 273 if( start+i == 0 ) 274 result = r; 275 276 if( _results ) 277 _results->data.fl[(start + i)*rstep] = r; 278 } 279 280 if( _neighbor_responses ) 281 { 282 dst = (float*)(_neighbor_responses->data.ptr + 283 (start + i)*_neighbor_responses->step); 284 for( j = 0; j < k1; j++ ) 285 dst[j] = nr[j].f; 286 for( ; j < k; j++ ) 287 dst[j] = 0.f; 288 } 289 290 if( _dist ) 291 { 292 dst = (float*)(_dist->data.ptr + (start + i)*_dist->step); 293 for( j = 0; j < k1; j++ ) 294 dst[j] = dist[j + i*k]; 295 for( ; j < k; j++ ) 296 dst[j] = 0.f; 297 } 298 } 299 300 return result; 301} 302 303 304 305float CvKNearest::find_nearest( const CvMat* _samples, int k, CvMat* _results, 306 const float** _neighbors, CvMat* _neighbor_responses, CvMat* _dist ) const 307{ 308 float result = 0.f; 309 bool local_alloc = false; 310 float* buf = 0; 311 const int max_blk_count = 128, max_buf_sz = 1 << 12; 312 313 CV_FUNCNAME( "CvKNearest::find_nearest" ); 314 315 __BEGIN__; 316 317 int i, count, count_scale, blk_count0, blk_count = 0, buf_sz, k1; 318 319 if( !samples ) 320 CV_ERROR( CV_StsError, "The search tree must be constructed first using train method" ); 321 322 if( !CV_IS_MAT(_samples) || 323 CV_MAT_TYPE(_samples->type) != CV_32FC1 || 324 _samples->cols != var_count ) 325 CV_ERROR( CV_StsBadArg, "Input samples must be floating-point matrix (<num_samples>x<var_count>)" ); 326 327 if( _results && (!CV_IS_MAT(_results) || 328 _results->cols != 1 && _results->rows != 1 || 329 _results->cols + _results->rows - 1 != _samples->rows) ) 330 CV_ERROR( CV_StsBadArg, 331 "The results must be 1d vector containing as much elements as the number of samples" ); 332 333 if( _results && CV_MAT_TYPE(_results->type) != CV_32FC1 && 334 (CV_MAT_TYPE(_results->type) != CV_32SC1 || regression)) 335 CV_ERROR( CV_StsUnsupportedFormat, 336 "The results must be floating-point or integer (in case of classification) vector" ); 337 338 if( k < 1 || k > max_k ) 339 CV_ERROR( CV_StsOutOfRange, "k must be within 1..max_k range" ); 340 341 if( _neighbor_responses ) 342 { 343 if( !CV_IS_MAT(_neighbor_responses) || CV_MAT_TYPE(_neighbor_responses->type) != CV_32FC1 || 344 _neighbor_responses->rows != _samples->rows || _neighbor_responses->cols != k ) 345 CV_ERROR( CV_StsBadArg, 346 "The neighbor responses (if present) must be floating-point matrix of <num_samples> x <k> size" ); 347 } 348 349 if( _dist ) 350 { 351 if( !CV_IS_MAT(_dist) || CV_MAT_TYPE(_dist->type) != CV_32FC1 || 352 _dist->rows != _samples->rows || _dist->cols != k ) 353 CV_ERROR( CV_StsBadArg, 354 "The distances from the neighbors (if present) must be floating-point matrix of <num_samples> x <k> size" ); 355 } 356 357 count = _samples->rows; 358 count_scale = k*2*sizeof(float); 359 blk_count0 = MIN( count, max_blk_count ); 360 buf_sz = MIN( blk_count0 * count_scale, max_buf_sz ); 361 blk_count0 = MAX( buf_sz/count_scale, 1 ); 362 blk_count0 += blk_count0 % 2; 363 blk_count0 = MIN( blk_count0, count ); 364 buf_sz = blk_count0 * count_scale + k*sizeof(float); 365 k1 = get_sample_count(); 366 k1 = MIN( k1, k ); 367 368 if( buf_sz <= CV_MAX_LOCAL_SIZE ) 369 { 370 buf = (float*)cvStackAlloc( buf_sz ); 371 local_alloc = true; 372 } 373 else 374 CV_CALL( buf = (float*)cvAlloc( buf_sz )); 375 376 for( i = 0; i < count; i += blk_count ) 377 { 378 blk_count = MIN( count - i, blk_count0 ); 379 float* neighbor_responses = buf; 380 float* dist = buf + blk_count*k; 381 Cv32suf* sort_buf = (Cv32suf*)(dist + blk_count*k); 382 383 find_neighbors_direct( _samples, k, i, i + blk_count, 384 neighbor_responses, _neighbors, dist ); 385 386 float r = write_results( k, k1, i, i + blk_count, neighbor_responses, dist, 387 _results, _neighbor_responses, _dist, sort_buf ); 388 if( i == 0 ) 389 result = r; 390 } 391 392 __END__; 393 394 if( !local_alloc ) 395 cvFree( &buf ); 396 397 return result; 398} 399 400/* End of file */ 401 402