/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//            Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "_ml.h"

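/*
   Normal (Gaussian) Bayes classifier.

   Each class is modelled as a multivariate normal distribution: training
   estimates a per-class mean vector and covariance matrix, and prediction
   assigns a sample to the class that minimizes
       log|Sigma_c| + (x - mu_c)' * Sigma_c^(-1) * (x - mu_c),
   i.e. the class with the largest Gaussian likelihood (class priors are
   not taken into account).
*/
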
CvNormalBayesClassifier::CvNormalBayesClassifier()
{
    var_count = var_all = 0;
    var_idx = 0;
    cls_labels = 0;
    count = 0;
    sum = 0;
    productsum = 0;
    avg = 0;
    inv_eigen_values = 0;
    cov_rotate_mats = 0;
    c = 0;
    default_model_name = "my_nb";
}


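// Releases all per-class matrices and model parameters.
// Note: count, sum, productsum, avg, inv_eigen_values and cov_rotate_mats
// are all views into the single array allocated for "count" (see train()),
// so only that one buffer is passed to cvFree().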
void CvNormalBayesClassifier::clear()
{
    if( cls_labels )
    {
        for( int cls = 0; cls < cls_labels->cols; cls++ )
        {
            cvReleaseMat( &count[cls] );
            cvReleaseMat( &sum[cls] );
            cvReleaseMat( &productsum[cls] );
            cvReleaseMat( &avg[cls] );
            cvReleaseMat( &inv_eigen_values[cls] );
            cvReleaseMat( &cov_rotate_mats[cls] );
        }
    }

    cvReleaseMat( &cls_labels );
    cvReleaseMat( &var_idx );
    cvReleaseMat( &c );
    cvFree( &count );
}


CvNormalBayesClassifier::~CvNormalBayesClassifier()
{
    clear();
}


CvNormalBayesClassifier::CvNormalBayesClassifier(
    const CvMat* _train_data, const CvMat* _responses,
    const CvMat* _var_idx, const CvMat* _sample_idx )
{
    var_count = var_all = 0;
    var_idx = 0;
    cls_labels = 0;
    count = 0;
    sum = 0;
    productsum = 0;
    avg = 0;
    inv_eigen_values = 0;
    cov_rotate_mats = 0;
    c = 0;
    default_model_name = "my_nb";

    train( _train_data, _responses, _var_idx, _sample_idx );
}


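// Trains the classifier on row samples (_train_data, CV_32FC1) with
// categorical responses. When update == true the new data is accumulated
// into the existing per-class statistics; it must then have the same
// dimensionality and the same set of class labels as the original
// training set.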
bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _responses,
                            const CvMat* _var_idx, const CvMat* _sample_idx, bool update )
{
    const float min_variation = FLT_EPSILON;
    bool result = false;
    CvMat* responses   = 0;
    const float** train_data = 0;
    CvMat* __cls_labels = 0;
    CvMat* __var_idx = 0;
    CvMat* cov = 0;

    CV_FUNCNAME( "CvNormalBayesClassifier::train" );

    __BEGIN__;

    int cls, nsamples = 0, _var_count = 0, _var_all = 0, nclasses = 0;
    int s, c1, c2;
    const int* responses_data;

    CV_CALL( cvPrepareTrainData( 0,
        _train_data, CV_ROW_SAMPLE, _responses, CV_VAR_CATEGORICAL,
        _var_idx, _sample_idx, false, &train_data,
        &nsamples, &_var_count, &_var_all, &responses,
        &__cls_labels, &__var_idx ));

    if( !update )
    {
        const size_t mat_size = sizeof(CvMat*);
        size_t data_size;

        clear();

        var_idx = __var_idx;
        cls_labels = __cls_labels;
        __var_idx = __cls_labels = 0;
        var_count = _var_count;
        var_all = _var_all;

        nclasses = cls_labels->cols;
        data_size = nclasses*6*mat_size;

        CV_CALL( count = (CvMat**)cvAlloc( data_size ));
        memset( count, 0, data_size );

        sum              = count      + nclasses;
        productsum       = sum        + nclasses;
        avg              = productsum + nclasses;
        inv_eigen_values = avg        + nclasses;
        cov_rotate_mats  = inv_eigen_values + nclasses;

        CV_CALL( c = cvCreateMat( 1, nclasses, CV_64FC1 ));

        for( cls = 0; cls < nclasses; cls++ )
        {
            CV_CALL(count[cls]            = cvCreateMat( 1, var_count, CV_32SC1 ));
            CV_CALL(sum[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));
            CV_CALL(productsum[cls]       = cvCreateMat( var_count, var_count, CV_64FC1 ));
            CV_CALL(avg[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));
            CV_CALL(inv_eigen_values[cls] = cvCreateMat( 1, var_count, CV_64FC1 ));
            CV_CALL(cov_rotate_mats[cls]  = cvCreateMat( var_count, var_count, CV_64FC1 ));
            CV_CALL(cvZero( count[cls] ));
            CV_CALL(cvZero( sum[cls] ));
            CV_CALL(cvZero( productsum[cls] ));
            CV_CALL(cvZero( avg[cls] ));
            CV_CALL(cvZero( inv_eigen_values[cls] ));
            CV_CALL(cvZero( cov_rotate_mats[cls] ));
        }
    }
    else
    {
        // check that the new training data has the same dimensionality etc.
        if( _var_count != var_count || _var_all != var_all ||
            !((!_var_idx && !var_idx) ||
              (_var_idx && var_idx && cvNorm(_var_idx,var_idx,CV_C) < DBL_EPSILON)) )
            CV_ERROR( CV_StsBadArg,
            "The new training data is inconsistent with the original training data" );

        if( cls_labels->cols != __cls_labels->cols ||
            cvNorm(cls_labels, __cls_labels, CV_C) > DBL_EPSILON )
            CV_ERROR( CV_StsNotImplemented,
            "In the current implementation the new training data must have absolutely "
            "the same set of class labels as used in the original training data" );

        nclasses = cls_labels->cols;
    }

    responses_data = responses->data.i;
    CV_CALL( cov = cvCreateMat( _var_count, _var_count, CV_64FC1 ));

    /* process train data (count, sum, productsum) */
    for( s = 0; s < nsamples; s++ )
    {
        cls = responses_data[s];
        int* count_data = count[cls]->data.i;
        double* sum_data = sum[cls]->data.db;
        double* prod_data = productsum[cls]->data.db;
        const float* train_vec = train_data[s];

        for( c1 = 0; c1 < _var_count; c1++, prod_data += _var_count )
        {
            double val1 = train_vec[c1];
            sum_data[c1] += val1;
            count_data[c1]++;
            for( c2 = c1; c2 < _var_count; c2++ )
                prod_data[c2] += train_vec[c2]*val1;
        }
    }

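    /* For every class the covariance is computed from the accumulated sums:
           cov(i,j) = ( productsum(i,j) - avg_i*sum_j - avg_j*sum_i
                        + n*avg_i*avg_j ) / (n - 1),
       then decomposed with SVD (cov = U*diag(w)*U'), so that
       log(det(cov)) = sum(log(w_k)) and the rotation can be used in
       predict() to apply cov^(-1). Eigenvalues are clipped from below by
       FLT_EPSILON to keep the inverse and the log-determinant finite. */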
    /* calculate avg, covariance matrix, c */
    for( cls = 0; cls < nclasses; cls++ )
    {
        double det = 1;
        int i, j;
        CvMat* w = inv_eigen_values[cls];
        int* count_data = count[cls]->data.i;
        double* avg_data = avg[cls]->data.db;
        double* sum1 = sum[cls]->data.db;

        cvCompleteSymm( productsum[cls], 0 );

        for( j = 0; j < _var_count; j++ )
        {
            int n = count_data[j];
            avg_data[j] = n ? sum1[j] / n : 0.;
        }

        count_data = count[cls]->data.i;
        avg_data = avg[cls]->data.db;
        sum1 = sum[cls]->data.db;

        for( i = 0; i < _var_count; i++ )
        {
            double* avg2_data = avg[cls]->data.db;
            double* sum2 = sum[cls]->data.db;
            double* prod_data = productsum[cls]->data.db + i*_var_count;
            double* cov_data = cov->data.db + i*_var_count;
            double s1val = sum1[i];
            double avg1 = avg_data[i];
            int count = count_data[i];

            for( j = 0; j <= i; j++ )
            {
                double avg2 = avg2_data[j];
                double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * count;
                cov_val = (count > 1) ? cov_val / (count - 1) : cov_val;
                cov_data[j] = cov_val;
            }
        }

        CV_CALL( cvCompleteSymm( cov, 1 ));
        CV_CALL( cvSVD( cov, w, cov_rotate_mats[cls], 0, CV_SVD_U_T ));
        CV_CALL( cvMaxS( w, min_variation, w ));
        for( j = 0; j < _var_count; j++ )
            det *= w->data.db[j];

        CV_CALL( cvDiv( NULL, w, w ));
        c->data.db[cls] = log( det );
    }

    result = true;

    __END__;

    if( !result || cvGetErrStatus() < 0 )
        clear();

    cvReleaseMat( &cov );
    cvReleaseMat( &__cls_labels );
    cvReleaseMat( &__var_idx );
    cvFree( &train_data );

    return result;
}


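// Classifies one or more row samples (CV_32FC1, var_all columns).
// Returns the predicted label of the first sample; when several samples are
// passed, per-sample labels are written to the mandatory "results" vector
// (CV_32SC1 or CV_32FC1).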
float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) const
{
    float value = 0;
    void* buffer = 0;
    int allocated_buffer = 0;

    CV_FUNCNAME( "CvNormalBayesClassifier::predict" );

    __BEGIN__;

    int i, j, k, cls = -1, _var_count, nclasses;
    double opt = FLT_MAX;
    CvMat diff;
    int rtype = 0, rstep = 0, size;
    const int* vidx = 0;

    nclasses = cls_labels->cols;
    _var_count = avg[0]->cols;

    if( !CV_IS_MAT(samples) || CV_MAT_TYPE(samples->type) != CV_32FC1 || samples->cols != var_all )
        CV_ERROR( CV_StsBadArg,
        "The input samples must be 32f matrix with the number of columns = var_all" );

    if( samples->rows > 1 && !results )
        CV_ERROR( CV_StsNullPtr,
        "When the number of input samples is >1, the output vector of results must be passed" );

    if( results )
    {
        if( !CV_IS_MAT(results) || (CV_MAT_TYPE(results->type) != CV_32FC1 &&
            CV_MAT_TYPE(results->type) != CV_32SC1) ||
            (results->cols != 1 && results->rows != 1) ||
            results->cols + results->rows - 1 != samples->rows )
            CV_ERROR( CV_StsBadArg, "The output array must be integer or floating-point vector "
            "with the number of elements = number of rows in the input matrix" );

        rtype = CV_MAT_TYPE(results->type);
        rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype);
    }

    if( var_idx )
        vidx = var_idx->data.i;

    // allocate a working buffer and initialize the header of the difference vector
    size = sizeof(double) * (nclasses + var_count);
    if( size <= CV_MAX_LOCAL_SIZE )
        buffer = cvStackAlloc( size );
    else
    {
        CV_CALL( buffer = cvAlloc( size ));
        allocated_buffer = 1;
    }

    diff = cvMat( 1, var_count, CV_64FC1, buffer );

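    /* For each sample and each class c evaluate
           score_c = log|cov_c| + (x - avg_c)' * cov_c^(-1) * (x - avg_c);
       the difference vector is rotated into the eigenbasis of cov_c and
       weighted by the stored inverse eigenvalues. The class with the
       smallest score (the largest Gaussian likelihood) wins. */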
    for( k = 0; k < samples->rows; k++ )
    {
        int ival;

        // restart the search for the best class for every sample
        opt = FLT_MAX;

        for( i = 0; i < nclasses; i++ )
        {
            double cur = c->data.db[i];
            CvMat* u = cov_rotate_mats[i];
            CvMat* w = inv_eigen_values[i];
            const double* avg_data = avg[i]->data.db;
            const float* x = (const float*)(samples->data.ptr + samples->step*k);

            // cov = u w u'  -->  cov^(-1) = u w^(-1) u'
            for( j = 0; j < _var_count; j++ )
                diff.data.db[j] = avg_data[j] - x[vidx ? vidx[j] : j];

            CV_CALL(cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T ));
            for( j = 0; j < _var_count; j++ )
            {
                double d = diff.data.db[j];
                cur += d*d*w->data.db[j];
            }

            if( cur < opt )
            {
                cls = i;
                opt = cur;
            }
            /* probability = exp( -0.5 * cur ) */
        }

        ival = cls_labels->data.i[cls];
        if( results )
        {
            if( rtype == CV_32SC1 )
                results->data.i[k*rstep] = ival;
            else
                results->data.fl[k*rstep] = (float)ival;
        }
        if( k == 0 )
            value = (float)ival;

        /*if( _probs )
        {
            CV_CALL( cvConvertScale( &expo, &expo, -0.5 ));
            CV_CALL( cvExp( &expo, &expo ));
            if( _probs->cols == 1 )
                CV_CALL( cvReshape( &expo, &expo, 1, nclasses ));
            CV_CALL( cvConvertScale( &expo, _probs, 1./cvSum( &expo ).val[0] ));
        }*/
    }

    __END__;

    if( allocated_buffer )
        cvFree( &buffer );

    return value;
}


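// Writes the trained model to file storage: var_count, var_all, the optional
// var_idx, cls_labels, the per-class statistics (count, sum, productsum, avg,
// inv_eigen_values, cov_rotate_mats) as sequences, and the vector c of
// log-determinants. read() restores the same layout.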
void CvNormalBayesClassifier::write( CvFileStorage* fs, const char* name )
{
    CV_FUNCNAME( "CvNormalBayesClassifier::write" );

    __BEGIN__;

    int nclasses, i;

    nclasses = cls_labels->cols;

    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_NBAYES );

    CV_CALL( cvWriteInt( fs, "var_count", var_count ));
    CV_CALL( cvWriteInt( fs, "var_all", var_all ));

    if( var_idx )
        CV_CALL( cvWrite( fs, "var_idx", var_idx ));
    CV_CALL( cvWrite( fs, "cls_labels", cls_labels ));

    CV_CALL( cvStartWriteStruct( fs, "count", CV_NODE_SEQ ));
    for( i = 0; i < nclasses; i++ )
        CV_CALL( cvWrite( fs, NULL, count[i] ));
    CV_CALL( cvEndWriteStruct( fs ));

    CV_CALL( cvStartWriteStruct( fs, "sum", CV_NODE_SEQ ));
    for( i = 0; i < nclasses; i++ )
        CV_CALL( cvWrite( fs, NULL, sum[i] ));
    CV_CALL( cvEndWriteStruct( fs ));

    CV_CALL( cvStartWriteStruct( fs, "productsum", CV_NODE_SEQ ));
    for( i = 0; i < nclasses; i++ )
        CV_CALL( cvWrite( fs, NULL, productsum[i] ));
    CV_CALL( cvEndWriteStruct( fs ));

    CV_CALL( cvStartWriteStruct( fs, "avg", CV_NODE_SEQ ));
    for( i = 0; i < nclasses; i++ )
        CV_CALL( cvWrite( fs, NULL, avg[i] ));
    CV_CALL( cvEndWriteStruct( fs ));

    CV_CALL( cvStartWriteStruct( fs, "inv_eigen_values", CV_NODE_SEQ ));
    for( i = 0; i < nclasses; i++ )
        CV_CALL( cvWrite( fs, NULL, inv_eigen_values[i] ));
    CV_CALL( cvEndWriteStruct( fs ));

    CV_CALL( cvStartWriteStruct( fs, "cov_rotate_mats", CV_NODE_SEQ ));
    for( i = 0; i < nclasses; i++ )
        CV_CALL( cvWrite( fs, NULL, cov_rotate_mats[i] ));
    CV_CALL( cvEndWriteStruct( fs ));

    CV_CALL( cvWrite( fs, "c", c ));

    cvEndWriteStruct( fs );

    __END__;
}


void CvNormalBayesClassifier::read( CvFileStorage* fs, CvFileNode* root_node )
{
    bool ok = false;
    CV_FUNCNAME( "CvNormalBayesClassifier::read" );

    __BEGIN__;

    int nclasses, i;
    size_t data_size;
    CvFileNode* node;
    CvSeq* seq;
    CvSeqReader reader;

    clear();

    CV_CALL( var_count = cvReadIntByName( fs, root_node, "var_count", -1 ));
    CV_CALL( var_all = cvReadIntByName( fs, root_node, "var_all", -1 ));
    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, root_node, "var_idx" ));
    CV_CALL( cls_labels = (CvMat*)cvReadByName( fs, root_node, "cls_labels" ));
    if( !cls_labels )
        CV_ERROR( CV_StsParseError, "No \"cls_labels\" in NBayes classifier" );
    if( cls_labels->cols < 1 )
        CV_ERROR( CV_StsBadArg, "Number of classes is less than 1" );
    if( var_count <= 0 )
        CV_ERROR( CV_StsParseError,
        "The field \"var_count\" of NBayes classifier is missing" );
    nclasses = cls_labels->cols;

    data_size = nclasses*6*sizeof(CvMat*);
    CV_CALL( count = (CvMat**)cvAlloc( data_size ));
    memset( count, 0, data_size );

    sum = count + nclasses;
    productsum = sum + nclasses;
    avg = productsum + nclasses;
    inv_eigen_values = avg + nclasses;
    cov_rotate_mats = inv_eigen_values + nclasses;

    CV_CALL( node = cvGetFileNodeByName( fs, root_node, "count" ));
    seq = node->data.seq;
    if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
        CV_ERROR( CV_StsBadArg, "" );
    CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    for( i = 0; i < nclasses; i++ )
    {
        CV_CALL( count[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    CV_CALL( node = cvGetFileNodeByName( fs, root_node, "sum" ));
    seq = node->data.seq;
    if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
        CV_ERROR( CV_StsBadArg, "" );
    CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    for( i = 0; i < nclasses; i++ )
    {
        CV_CALL( sum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    CV_CALL( node = cvGetFileNodeByName( fs, root_node, "productsum" ));
    seq = node->data.seq;
    if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
        CV_ERROR( CV_StsBadArg, "" );
    CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    for( i = 0; i < nclasses; i++ )
    {
        CV_CALL( productsum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    CV_CALL( node = cvGetFileNodeByName( fs, root_node, "avg" ));
    seq = node->data.seq;
    if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
        CV_ERROR( CV_StsBadArg, "" );
    CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    for( i = 0; i < nclasses; i++ )
    {
        CV_CALL( avg[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    CV_CALL( node = cvGetFileNodeByName( fs, root_node, "inv_eigen_values" ));
    seq = node->data.seq;
    if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
        CV_ERROR( CV_StsBadArg, "" );
    CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    for( i = 0; i < nclasses; i++ )
    {
        CV_CALL( inv_eigen_values[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    CV_CALL( node = cvGetFileNodeByName( fs, root_node, "cov_rotate_mats" ));
    seq = node->data.seq;
    if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
        CV_ERROR( CV_StsBadArg, "" );
    CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    for( i = 0; i < nclasses; i++ )
    {
        CV_CALL( cov_rotate_mats[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    CV_CALL( c = (CvMat*)cvReadByName( fs, root_node, "c" ));

    ok = true;

    __END__;

    if( !ok )
        clear();
}
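
/* Illustrative usage sketch (not part of the library; matrix sizes, names and
   data below are hypothetical):

       CvMat* samples   = cvCreateMat( nsamples, nvars, CV_32FC1 );  // one row per sample
       CvMat* responses = cvCreateMat( nsamples, 1, CV_32FC1 );      // integer class labels
       // ... fill samples and responses ...

       CvNormalBayesClassifier nb;
       nb.train( samples, responses );             // estimate per-class Gaussians

       CvMat sample0;
       cvGetRow( samples, &sample0, 0 );
       float label = nb.predict( &sample0 );       // label of a single sample

       CvMat* out = cvCreateMat( nsamples, 1, CV_32FC1 );
       nb.predict( samples, out );                 // labels for all samples
*/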

/* End of file. */