1/*M///////////////////////////////////////////////////////////////////////////////////////
2//
3//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4//
5//  By downloading, copying, installing or using the software you agree to this license.
6//  If you do not agree to this license, do not download, install,
7//  copy or use the software.
8//
9//
10//                        Intel License Agreement
11//
12// Copyright (C) 2000, Intel Corporation, all rights reserved.
13// Third party copyrights are property of their respective owners.
14//
15// Redistribution and use in source and binary forms, with or without modification,
16// are permitted provided that the following conditions are met:
17//
18//   * Redistribution's of source code must retain the above copyright notice,
19//     this list of conditions and the following disclaimer.
20//
21//   * Redistribution's in binary form must reproduce the above copyright notice,
22//     this list of conditions and the following disclaimer in the documentation
23//     and/or other materials provided with the distribution.
24//
25//   * The name of Intel Corporation may not be used to endorse or promote products
26//     derived from this software without specific prior written permission.
27//
28// This software is provided by the copyright holders and contributors "as is" and
29// any express or implied warranties, including, but not limited to, the implied
30// warranties of merchantability and fitness for a particular purpose are disclaimed.
31// In no event shall the Intel Corporation or contributors be liable for any direct,
32// indirect, incidental, special, exemplary, or consequential damages
33// (including, but not limited to, procurement of substitute goods or services;
34// loss of use, data, or profits; or business interruption) however caused
35// and on any theory of liability, whether in contract, strict liability,
36// or tort (including negligence or otherwise) arising in any way out of
37// the use of this software, even if advised of the possibility of such damage.
38//
39//M*/
40
41#include "_ml.h"
42
43/****************************************************************************************\
44                                COPYRIGHT NOTICE
45                                ----------------
46
47  The code has been derived from libsvm library (version 2.6)
48  (http://www.csie.ntu.edu.tw/~cjlin/libsvm).
49
  Here is the original copyright:
51------------------------------------------------------------------------------------------
52    Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
53    All rights reserved.
54
55    Redistribution and use in source and binary forms, with or without
56    modification, are permitted provided that the following conditions
57    are met:
58
59    1. Redistributions of source code must retain the above copyright
60    notice, this list of conditions and the following disclaimer.
61
62    2. Redistributions in binary form must reproduce the above copyright
63    notice, this list of conditions and the following disclaimer in the
64    documentation and/or other materials provided with the distribution.
65
66    3. Neither name of copyright holders nor the names of its contributors
67    may be used to endorse or promote products derived from this software
68    without specific prior written permission.
69
70
71    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
72    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
73    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
74    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
75    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
76    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
77    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
78    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
79    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
80    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
81    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
82\****************************************************************************************/
83
84#define CV_SVM_MIN_CACHE_SIZE  (40 << 20)  /* 40Mb */
85
86#include <stdarg.h>
87#include <ctype.h>
88
89#if _MSC_VER >= 1200
90#pragma warning( disable: 4514 ) /* unreferenced inline functions */
91#endif
92
93#if 1
94typedef float Qfloat;
95#define QFLOAT_TYPE CV_32F
96#else
97typedef double Qfloat;
98#define QFLOAT_TYPE CV_64F
99#endif
100
101// Param Grid
102bool CvParamGrid::check() const
103{
104    bool ok = false;
105
106    CV_FUNCNAME( "CvParamGrid::check" );
107    __BEGIN__;
108
109    if( min_val > max_val )
110        CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
111    if( min_val < DBL_EPSILON )
112        CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
113    if( step < 1. + FLT_EPSILON )
114        CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
115
116    ok = true;
117
118    __END__;
119
120    return ok;
121}
122
123CvParamGrid CvSVM::get_default_grid( int param_id )
124{
125    CvParamGrid grid;
126    if( param_id == CvSVM::C )
127    {
128        grid.min_val = 0.1;
129        grid.max_val = 500;
130        grid.step = 5; // total iterations = 5
131    }
132    else if( param_id == CvSVM::GAMMA )
133    {
134        grid.min_val = 1e-5;
135        grid.max_val = 0.6;
136        grid.step = 15; // total iterations = 4
137    }
138    else if( param_id == CvSVM::P )
139    {
140        grid.min_val = 0.01;
141        grid.max_val = 100;
142        grid.step = 7; // total iterations = 4
143    }
144    else if( param_id == CvSVM::NU )
145    {
146        grid.min_val = 0.01;
147        grid.max_val = 0.2;
148        grid.step = 3; // total iterations = 3
149    }
150    else if( param_id == CvSVM::COEF )
151    {
152        grid.min_val = 0.1;
153        grid.max_val = 300;
154        grid.step = 14; // total iterations = 3
155    }
156    else if( param_id == CvSVM::DEGREE )
157    {
158        grid.min_val = 0.01;
159        grid.max_val = 4;
160        grid.step = 7; // total iterations = 3
161    }
162    else
163        cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
164            "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
165    return grid;
166}
167
168// SVM training parameters
169CvSVMParams::CvSVMParams() :
170    svm_type(CvSVM::C_SVC), kernel_type(CvSVM::RBF), degree(0),
171    gamma(1), coef0(0), C(1), nu(0), p(0), class_weights(0)
172{
173    term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
174}
175
176
177CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
178    double _degree, double _gamma, double _coef0,
179    double _Con, double _nu, double _p,
180    CvMat* _class_weights, CvTermCriteria _term_crit ) :
181    svm_type(_svm_type), kernel_type(_kernel_type),
182    degree(_degree), gamma(_gamma), coef0(_coef0),
183    C(_Con), nu(_nu), p(_p), class_weights(_class_weights), term_crit(_term_crit)
184{
185}
186
187
188/////////////////////////////////////// SVM kernel ///////////////////////////////////////
189
190CvSVMKernel::CvSVMKernel()
191{
192    clear();
193}
194
195
196void CvSVMKernel::clear()
197{
198    params = 0;
199    calc_func = 0;
200}
201
202
// Nothing to release: the destructor body is empty.
CvSVMKernel::~CvSVMKernel()
{
}
206
207
208CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
209{
210    clear();
211    create( _params, _calc_func );
212}
213
214
215bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
216{
217    clear();
218    params = _params;
219    calc_func = _calc_func;
220
221    if( !calc_func )
222        calc_func = params->kernel_type == CvSVM::RBF ? &CvSVMKernel::calc_rbf :
223                    params->kernel_type == CvSVM::POLY ? &CvSVMKernel::calc_poly :
224                    params->kernel_type == CvSVM::SIGMOID ? &CvSVMKernel::calc_sigmoid :
225                    &CvSVMKernel::calc_linear;
226
227    return true;
228}
229
230
231void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
232                                     const float* another, Qfloat* results,
233                                     double alpha, double beta )
234{
235    int j, k;
236    for( j = 0; j < vcount; j++ )
237    {
238        const float* sample = vecs[j];
239        double s = 0;
240        for( k = 0; k <= var_count - 4; k += 4 )
241            s += sample[k]*another[k] + sample[k+1]*another[k+1] +
242                 sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
243        for( ; k < var_count; k++ )
244            s += sample[k]*another[k];
245        results[j] = (Qfloat)(s*alpha + beta);
246    }
247}
248
249
250void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
251                               const float* another, Qfloat* results )
252{
253    calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
254}
255
256
257void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
258                             const float* another, Qfloat* results )
259{
260    CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
261    calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
262    cvPow( &R, &R, params->degree );
263}
264
265
266void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
267                                const float* another, Qfloat* results )
268{
269    int j;
270    calc_non_rbf_base( vcount, var_count, vecs, another, results,
271                       -2*params->gamma, -2*params->coef0 );
272    // TODO: speedup this
273    for( j = 0; j < vcount; j++ )
274    {
275        Qfloat t = results[j];
276        double e = exp(-fabs(t));
277        if( t > 0 )
278            results[j] = (Qfloat)((1. - e)/(1. + e));
279        else
280            results[j] = (Qfloat)((e - 1.)/(e + 1.));
281    }
282}
283
284
285void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
286                            const float* another, Qfloat* results )
287{
288    CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
289    double gamma = -params->gamma;
290    int j, k;
291
292    for( j = 0; j < vcount; j++ )
293    {
294        const float* sample = vecs[j];
295        double s = 0;
296
297        for( k = 0; k <= var_count - 4; k += 4 )
298        {
299            double t0 = sample[k] - another[k];
300            double t1 = sample[k+1] - another[k+1];
301
302            s += t0*t0 + t1*t1;
303
304            t0 = sample[k+2] - another[k+2];
305            t1 = sample[k+3] - another[k+3];
306
307            s += t0*t0 + t1*t1;
308        }
309
310        for( ; k < var_count; k++ )
311        {
312            double t0 = sample[k] - another[k];
313            s += t0*t0;
314        }
315        results[j] = (Qfloat)(s*gamma);
316    }
317
318    cvExp( &R, &R );
319}
320
321
322void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
323                        const float* another, Qfloat* results )
324{
325    const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
326    int j;
327    (this->*calc_func)( vcount, var_count, vecs, another, results );
328    for( j = 0; j < vcount; j++ )
329    {
330        if( results[j] > max_val )
331            results[j] = max_val;
332    }
333}
334
335
336// Generalized SMO+SVMlight algorithm
337// Solves:
338//
339//  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
340//
341//      y^T \alpha = \delta
342//      y_i = +1 or -1
343//      0 <= alpha_i <= Cp for y_i = 1
344//      0 <= alpha_i <= Cn for y_i = -1
345//
346// Given:
347//
348//  Q, b, y, Cp, Cn, and an initial feasible point \alpha
349//  l is the size of vectors and matrices
350//  eps is the stopping criterion
351//
352// solution will be put in \alpha, objective value will be put in obj
353//
354
355void CvSVMSolver::clear()
356{
357    G = 0;
358    alpha = 0;
359    y = 0;
360    b = 0;
361    buf[0] = buf[1] = 0;
362    cvReleaseMemStorage( &storage );
363    kernel = 0;
364    select_working_set_func = 0;
365    calc_rho_func = 0;
366
367    rows = 0;
368    samples = 0;
369    get_row_func = 0;
370}
371
372
// Creates an empty solver.
CvSVMSolver::CvSVMSolver()
{
    // storage must be null before clear() hands it to cvReleaseMemStorage
    storage = 0;
    clear();
}
378
379
// Releases the solver's memory storage via clear().
CvSVMSolver::~CvSVMSolver()
{
    clear();
}
384
385
// Constructs a solver and immediately sets it up with create(); all
// arguments are forwarded unchanged. The result of create() is not
// checked here.
CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, schar* _y,
                int _alpha_count, double* _alpha, double _Cp, double _Cn,
                CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
                SelectWorkingSet _select_working_set, CalcRho _calc_rho )
{
    // storage must be null before create() -> clear() tries to release it
    storage = 0;
    create( _sample_count, _var_count, _samples, _y, _alpha_count, _alpha, _Cp, _Cn,
            _storage, _kernel, _get_row, _select_working_set, _calc_rho );
}
395
396
397bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, schar* _y,
398                int _alpha_count, double* _alpha, double _Cp, double _Cn,
399                CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
400                SelectWorkingSet _select_working_set, CalcRho _calc_rho )
401{
402    bool ok = false;
403    int i, svm_type;
404
405    CV_FUNCNAME( "CvSVMSolver::create" );
406
407    __BEGIN__;
408
409    int rows_hdr_size;
410
411    clear();
412
413    sample_count = _sample_count;
414    var_count = _var_count;
415    samples = _samples;
416    y = _y;
417    alpha_count = _alpha_count;
418    alpha = _alpha;
419    kernel = _kernel;
420
421    C[0] = _Cn;
422    C[1] = _Cp;
423    eps = kernel->params->term_crit.epsilon;
424    max_iter = kernel->params->term_crit.max_iter;
425    storage = cvCreateChildMemStorage( _storage );
426
427    b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
428    alpha_status = (schar*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
429    G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
430    for( i = 0; i < 2; i++ )
431        buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
432    svm_type = kernel->params->svm_type;
433
434    select_working_set_func = _select_working_set;
435    if( !select_working_set_func )
436        select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
437        &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;
438
439    calc_rho_func = _calc_rho;
440    if( !calc_rho_func )
441        calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
442            &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;
443
444    get_row_func = _get_row;
445    if( !get_row_func )
446        get_row_func = params->svm_type == CvSVM::EPS_SVR ||
447                       params->svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
448                       params->svm_type == CvSVM::C_SVC ||
449                       params->svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
450                       &CvSVMSolver::get_row_one_class;
451
452    cache_line_size = sample_count*sizeof(Qfloat);
453    // cache size = max(num_of_samples^2*sizeof(Qfloat)*0.25, 64Kb)
454    // (assuming that for large training sets ~25% of Q matrix is used)
455    cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );
456
457    // the size of Q matrix row headers
458    rows_hdr_size = sample_count*sizeof(rows[0]);
459    if( rows_hdr_size > storage->block_size )
460        CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );
461
462    lru_list.prev = lru_list.next = &lru_list;
463    rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
464    memset( rows, 0, rows_hdr_size );
465
466    ok = true;
467
468    __END__;
469
470    return ok;
471}
472
473
474float* CvSVMSolver::get_row_base( int i, bool* _existed )
475{
476    int i1 = i < sample_count ? i : i - sample_count;
477    CvSVMKernelRow* row = rows + i1;
478    bool existed = row->data != 0;
479    Qfloat* data;
480
481    if( existed || cache_size <= 0 )
482    {
483        CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
484        data = del_row->data;
485        assert( data != 0 );
486
487        // delete row from the LRU list
488        del_row->data = 0;
489        del_row->prev->next = del_row->next;
490        del_row->next->prev = del_row->prev;
491    }
492    else
493    {
494        data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
495        cache_size -= cache_line_size;
496    }
497
498    // insert row into the LRU list
499    row->data = data;
500    row->prev = &lru_list;
501    row->next = lru_list.next;
502    row->prev->next = row->next->prev = row;
503
504    if( !existed )
505    {
506        kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
507    }
508
509    if( _existed )
510        *_existed = existed;
511
512    return row->data;
513}
514
515
516float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
517{
518    if( !existed )
519    {
520        const schar* _y = y;
521        int j, len = sample_count;
522        assert( _y && i < sample_count );
523
524        if( _y[i] > 0 )
525        {
526            for( j = 0; j < len; j++ )
527                row[j] = _y[j]*row[j];
528        }
529        else
530        {
531            for( j = 0; j < len; j++ )
532                row[j] = -_y[j]*row[j];
533        }
534    }
535    return row;
536}
537
538
// One-class variant of the row post-processing: the kernel row is returned
// as is; the index, destination buffer and cache flag are ignored.
float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
{
    return row;
}
543
544
545float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
546{
547    int j, len = sample_count;
548    Qfloat* dst_pos = dst;
549    Qfloat* dst_neg = dst + len;
550    if( i >= len )
551    {
552        Qfloat* temp;
553        CV_SWAP( dst_pos, dst_neg, temp );
554    }
555
556    for( j = 0; j < len; j++ )
557    {
558        Qfloat t = row[j];
559        dst_pos[j] = t;
560        dst_neg[j] = -t;
561    }
562    return dst;
563}
564
565
566
567float* CvSVMSolver::get_row( int i, float* dst )
568{
569    bool existed = false;
570    float* row = get_row_base( i, &existed );
571    return (this->*get_row_func)( i, row, dst, existed );
572}
573
574
// Helper macros for the solver methods below. They expand in place and
// refer to the names alpha_status, alpha, y and C of the enclosing scope.

// alpha_status[i] > 0: alpha[i] has reached its upper bound get_C(i)
#undef is_upper_bound
#define is_upper_bound(i) (alpha_status[i] > 0)

// alpha_status[i] < 0: alpha[i] is at (or below) zero
#undef is_lower_bound
#define is_lower_bound(i) (alpha_status[i] < 0)

// alpha_status[i] == 0: alpha[i] is strictly between the bounds
#undef is_free
#define is_free(i) (alpha_status[i] == 0)

// upper bound for variable i: C[1] when y[i] > 0, C[0] otherwise
#undef get_C
#define get_C(i) (C[y[i]>0])

// recomputes alpha_status[i] from alpha[i]: 1 = upper bound, -1 = lower bound, 0 = free
#undef update_alpha_status
#define update_alpha_status(i) \
    alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)

#undef reconstruct_gradient
#define reconstruct_gradient() /* empty for now */
593
594
// Runs the optimization loop on the problem prepared by create() and by the
// calling solve_* method (alpha, b, y and C must already be filled in).
// Each iteration picks a pair of variables through select_working_set_func,
// updates the two alphas analytically, clips them to their bounds and then
// updates the gradient G.
// The loop stops when the working-set selection reports optimality or when
// max_iter iterations have been performed.
// On success fills si (rho, r, obj, upper_bound_p, upper_bound_n) and returns
// true. Returns false when an initial gradient entry exceeds 1e200 in
// magnitude (and, in _DEBUG builds, when G or alpha grow beyond sanity limits).
bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
{
    int iter = 0;
    int i, j, k;

    // 1. initialize gradient and alpha status
    for( i = 0; i < alpha_count; i++ )
    {
        update_alpha_status(i);
        G[i] = b[i];
        if( fabs(G[i]) > 1e200 )
            return false;
    }

    // G = b + Q*alpha; rows whose alpha is at the lower bound are skipped
    for( i = 0; i < alpha_count; i++ )
    {
        if( !is_lower_bound(i) )
        {
            const Qfloat *Q_i = get_row( i, buf[0] );
            double alpha_i = alpha[i];

            for( j = 0; j < alpha_count; j++ )
                G[j] += alpha_i*Q_i[j];
        }
    }

    // 2. optimization loop
    for(;;)
    {
        const Qfloat *Q_i, *Q_j;
        double C_i, C_j;
        double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
        double delta_alpha_i, delta_alpha_j;

#ifdef _DEBUG
        // sanity check: give up if the gradient or the solution diverges
        for( i = 0; i < alpha_count; i++ )
        {
            if( fabs(G[i]) > 1e+300 )
                return false;

            if( fabs(alpha[i]) > 1e16 )
                return false;
        }
#endif

        // a non-zero result means "already optimal"; iter counts performed iterations
        if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
            break;

        // separate scratch buffers so that both rows stay valid at the same time
        Q_i = get_row( i, buf[0] );
        Q_j = get_row( j, buf[1] );

        C_i = get_C(i);
        C_j = get_C(j);

        alpha_i = old_alpha_i = alpha[i];
        alpha_j = old_alpha_j = alpha[j];

        if( y[i] != y[j] )
        {
            // different labels: both alphas move by the same delta, so
            // diff = alpha_i - alpha_j is preserved, also by the clipping below
            double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
            // the denominator magnitude is kept at least FLT_EPSILON
            double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
            double diff = alpha_i - alpha_j;
            alpha_i += delta;
            alpha_j += delta;

            // clip to the lower bounds
            if( diff > 0 && alpha_j < 0 )
            {
                alpha_j = 0;
                alpha_i = diff;
            }
            else if( diff <= 0 && alpha_i < 0 )
            {
                alpha_i = 0;
                alpha_j = -diff;
            }

            // clip to the upper bounds
            if( diff > C_i - C_j && alpha_i > C_i )
            {
                alpha_i = C_i;
                alpha_j = C_i - diff;
            }
            else if( diff <= C_i - C_j && alpha_j > C_j )
            {
                alpha_j = C_j;
                alpha_i = C_j + diff;
            }
        }
        else
        {
            // equal labels: the alphas move in opposite directions, so
            // sum = alpha_i + alpha_j is preserved, also by the clipping below
            double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
            // the denominator magnitude is kept at least FLT_EPSILON
            double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
            double sum = alpha_i + alpha_j;
            alpha_i -= delta;
            alpha_j += delta;

            // clip alpha_i from above / alpha_j from below
            if( sum > C_i && alpha_i > C_i )
            {
                alpha_i = C_i;
                alpha_j = sum - C_i;
            }
            else if( sum <= C_i && alpha_j < 0)
            {
                alpha_j = 0;
                alpha_i = sum;
            }

            // clip alpha_j from above / alpha_i from below
            if( sum > C_j && alpha_j > C_j )
            {
                alpha_j = C_j;
                alpha_i = sum - C_j;
            }
            else if( sum <= C_j && alpha_i < 0 )
            {
                alpha_i = 0;
                alpha_j = sum;
            }
        }

        // update alpha
        alpha[i] = alpha_i;
        alpha[j] = alpha_j;
        update_alpha_status(i);
        update_alpha_status(j);

        // update G
        delta_alpha_i = alpha_i - old_alpha_i;
        delta_alpha_j = alpha_j - old_alpha_j;

        for( k = 0; k < alpha_count; k++ )
            G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
    }

    // calculate rho
    (this->*calc_rho_func)( si.rho, si.r );

    // calculate objective value
    // 0.5*sum(alpha[i]*(G[i] + b[i])), with the factor 0.5 applied after the loop
    for( i = 0, si.obj = 0; i < alpha_count; i++ )
        si.obj += alpha[i] * (G[i] + b[i]);

    si.obj *= 0.5;

    // report the bounds that were used: C[1] for y > 0, C[0] otherwise
    si.upper_bound_p = C[1];
    si.upper_bound_n = C[0];

    return true;
}
741
742
743// return 1 if already optimal, return 0 otherwise
744bool
745CvSVMSolver::select_working_set( int& out_i, int& out_j )
746{
747    // return i,j which maximize -grad(f)^T d , under constraint
748    // if alpha_i == C, d != +1
749    // if alpha_i == 0, d != -1
750    double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
751    int Gmax1_idx = -1;
752
753    double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
754    int Gmax2_idx = -1;
755
756    int i;
757
758    for( i = 0; i < alpha_count; i++ )
759    {
760        double t;
761
762        if( y[i] > 0 )    // y = +1
763        {
764            if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
765            {
766                Gmax1 = t;
767                Gmax1_idx = i;
768            }
769            if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
770            {
771                Gmax2 = t;
772                Gmax2_idx = i;
773            }
774        }
775        else        // y = -1
776        {
777            if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
778            {
779                Gmax2 = t;
780                Gmax2_idx = i;
781            }
782            if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
783            {
784                Gmax1 = t;
785                Gmax1_idx = i;
786            }
787        }
788    }
789
790    out_i = Gmax1_idx;
791    out_j = Gmax2_idx;
792
793    return Gmax1 + Gmax2 < eps;
794}
795
796
797void
798CvSVMSolver::calc_rho( double& rho, double& r )
799{
800    int i, nr_free = 0;
801    double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
802
803    for( i = 0; i < alpha_count; i++ )
804    {
805        double yG = y[i]*G[i];
806
807        if( is_lower_bound(i) )
808        {
809            if( y[i] > 0 )
810                ub = MIN(ub,yG);
811            else
812                lb = MAX(lb,yG);
813        }
814        else if( is_upper_bound(i) )
815        {
816            if( y[i] < 0)
817                ub = MIN(ub,yG);
818            else
819                lb = MAX(lb,yG);
820        }
821        else
822        {
823            ++nr_free;
824            sum_free += yG;
825        }
826    }
827
828    rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
829    r = 0;
830}
831
832
833bool
834CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
835{
836    // return i,j which maximize -grad(f)^T d , under constraint
837    // if alpha_i == C, d != +1
838    // if alpha_i == 0, d != -1
839    double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
840    int Gmax1_idx = -1;
841
842    double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
843    int Gmax2_idx = -1;
844
845    double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
846    int Gmax3_idx = -1;
847
848    double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
849    int Gmax4_idx = -1;
850
851    int i;
852
853    for( i = 0; i < alpha_count; i++ )
854    {
855        double t;
856
857        if( y[i] > 0 )    // y == +1
858        {
859            if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
860            {
861                Gmax1 = t;
862                Gmax1_idx = i;
863            }
864            if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
865            {
866                Gmax2 = t;
867                Gmax2_idx = i;
868            }
869        }
870        else        // y == -1
871        {
872            if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
873            {
874                Gmax3 = t;
875                Gmax3_idx = i;
876            }
877            if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
878            {
879                Gmax4 = t;
880                Gmax4_idx = i;
881            }
882        }
883    }
884
885    if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
886        return 1;
887
888    if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
889    {
890        out_i = Gmax1_idx;
891        out_j = Gmax2_idx;
892    }
893    else
894    {
895        out_i = Gmax3_idx;
896        out_j = Gmax4_idx;
897    }
898    return 0;
899}
900
901
902void
903CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
904{
905    int nr_free1 = 0, nr_free2 = 0;
906    double ub1 = DBL_MAX, ub2 = DBL_MAX;
907    double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
908    double sum_free1 = 0, sum_free2 = 0;
909    double r1, r2;
910
911    int i;
912
913    for( i = 0; i < alpha_count; i++ )
914    {
915        double G_i = G[i];
916        if( y[i] > 0 )
917        {
918            if( is_lower_bound(i) )
919                ub1 = MIN( ub1, G_i );
920            else if( is_upper_bound(i) )
921                lb1 = MAX( lb1, G_i );
922            else
923            {
924                ++nr_free1;
925                sum_free1 += G_i;
926            }
927        }
928        else
929        {
930            if( is_lower_bound(i) )
931                ub2 = MIN( ub2, G_i );
932            else if( is_upper_bound(i) )
933                lb2 = MAX( lb2, G_i );
934            else
935            {
936                ++nr_free2;
937                sum_free2 += G_i;
938            }
939        }
940    }
941
942    r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
943    r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
944
945    rho = (r1 - r2)*0.5;
946    r = (r1 + r2)*0.5;
947}
948
949
950/*
951///////////////////////// construct and solve various formulations ///////////////////////
952*/
953
954bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
955                               double _Cp, double _Cn, CvMemStorage* _storage,
956                               CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
957{
958    int i;
959
960    if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
961                 _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
962                 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
963        return false;
964
965    for( i = 0; i < sample_count; i++ )
966    {
967        alpha[i] = 0;
968        b[i] = -1;
969    }
970
971    if( !solve_generic( _si ))
972        return false;
973
974    for( i = 0; i < sample_count; i++ )
975        alpha[i] *= y[i];
976
977    return true;
978}
979
980
981bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
982                                CvMemStorage* _storage, CvSVMKernel* _kernel,
983                                double* _alpha, CvSVMSolutionInfo& _si )
984{
985    int i;
986    double sum_pos, sum_neg, inv_r;
987
988    if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
989                 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
990                 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
991        return false;
992
993    sum_pos = kernel->params->nu * sample_count * 0.5;
994    sum_neg = kernel->params->nu * sample_count * 0.5;
995
996    for( i = 0; i < sample_count; i++ )
997    {
998        if( y[i] > 0 )
999        {
1000            alpha[i] = MIN(1.0, sum_pos);
1001            sum_pos -= alpha[i];
1002        }
1003        else
1004        {
1005            alpha[i] = MIN(1.0, sum_neg);
1006            sum_neg -= alpha[i];
1007        }
1008        b[i] = 0;
1009    }
1010
1011    if( !solve_generic( _si ))
1012        return false;
1013
1014    inv_r = 1./_si.r;
1015
1016    for( i = 0; i < sample_count; i++ )
1017        alpha[i] *= y[i]*inv_r;
1018
1019    _si.rho *= inv_r;
1020    _si.obj *= (inv_r*inv_r);
1021    _si.upper_bound_p = inv_r;
1022    _si.upper_bound_n = inv_r;
1023
1024    return true;
1025}
1026
1027
1028bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
1029                                   CvMemStorage* _storage, CvSVMKernel* _kernel,
1030                                   double* _alpha, CvSVMSolutionInfo& _si )
1031{
1032    int i, n;
1033    double nu = _kernel->params->nu;
1034
1035    if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
1036                 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
1037                 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1038        return false;
1039
1040    y = (schar*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );
1041    n = cvRound( nu*sample_count );
1042
1043    for( i = 0; i < sample_count; i++ )
1044    {
1045        y[i] = 1;
1046        b[i] = 0;
1047        alpha[i] = i < n ? 1 : 0;
1048    }
1049
1050    if( n < sample_count )
1051        alpha[n] = nu * sample_count - n;
1052    else
1053        alpha[n-1] = nu * sample_count - (n-1);
1054
1055    return solve_generic(_si);
1056}
1057
1058
1059bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
1060                                 const float* _y, CvMemStorage* _storage,
1061                                 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1062{
1063    int i;
1064    double p = _kernel->params->p, C = _kernel->params->C;
1065
1066    if( !create( _sample_count, _var_count, _samples, 0,
1067                 _sample_count*2, 0, C, C, _storage, _kernel, &CvSVMSolver::get_row_svr,
1068                 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1069        return false;
1070
1071    y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1072    alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1073
1074    for( i = 0; i < sample_count; i++ )
1075    {
1076        alpha[i] = 0;
1077        b[i] = p - _y[i];
1078        y[i] = 1;
1079
1080        alpha[i+sample_count] = 0;
1081        b[i+sample_count] = p + _y[i];
1082        y[i+sample_count] = -1;
1083    }
1084
1085    if( !solve_generic( _si ))
1086        return false;
1087
1088    for( i = 0; i < sample_count; i++ )
1089        _alpha[i] = alpha[i] - alpha[i+sample_count];
1090
1091    return true;
1092}
1093
1094
1095bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
1096                                const float* _y, CvMemStorage* _storage,
1097                                CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1098{
1099    int i;
1100    double C = _kernel->params->C, sum;
1101
1102    if( !create( _sample_count, _var_count, _samples, 0,
1103                 _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
1104                 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
1105        return false;
1106
1107    y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1108    alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1109    sum = C * _kernel->params->nu * sample_count * 0.5;
1110
1111    for( i = 0; i < sample_count; i++ )
1112    {
1113        alpha[i] = alpha[i + sample_count] = MIN(sum, C);
1114        sum -= alpha[i];
1115
1116        b[i] = -_y[i];
1117        y[i] = 1;
1118
1119        b[i + sample_count] = _y[i];
1120        y[i + sample_count] = -1;
1121    }
1122
1123    if( !solve_generic( _si ))
1124        return false;
1125
1126    for( i = 0; i < sample_count; i++ )
1127        _alpha[i] = alpha[i] - alpha[i+sample_count];
1128
1129    return true;
1130}
1131
1132
1133//////////////////////////////////////////////////////////////////////////////////////////
1134
1135CvSVM::CvSVM()
1136{
1137    decision_func = 0;
1138    class_labels = 0;
1139    class_weights = 0;
1140    storage = 0;
1141    var_idx = 0;
1142    kernel = 0;
1143    solver = 0;
1144    default_model_name = "my_svm";
1145
1146    clear();
1147}
1148
1149
1150CvSVM::~CvSVM()
1151{
1152    clear();
1153}
1154
1155
1156void CvSVM::clear()
1157{
1158    cvFree( &decision_func );
1159    cvReleaseMat( &class_labels );
1160    cvReleaseMat( &class_weights );
1161    cvReleaseMemStorage( &storage );
1162    cvReleaseMat( &var_idx );
1163    delete kernel;
1164    delete solver;
1165    kernel = 0;
1166    solver = 0;
1167    var_all = 0;
1168    sv = 0;
1169    sv_total = 0;
1170}
1171
1172
1173CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
1174    const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1175{
1176    decision_func = 0;
1177    class_labels = 0;
1178    class_weights = 0;
1179    storage = 0;
1180    var_idx = 0;
1181    kernel = 0;
1182    solver = 0;
1183    default_model_name = "my_svm";
1184
1185    train( _train_data, _responses, _var_idx, _sample_idx, _params );
1186}
1187
1188
1189int CvSVM::get_support_vector_count() const
1190{
1191    return sv_total;
1192}
1193
1194
1195const float* CvSVM::get_support_vector(int i) const
1196{
1197    return sv && (unsigned)i < (unsigned)sv_total ? sv[i] : 0;
1198}
1199
1200
1201bool CvSVM::set_params( const CvSVMParams& _params )
1202{
1203    bool ok = false;
1204
1205    CV_FUNCNAME( "CvSVM::set_params" );
1206
1207    __BEGIN__;
1208
1209    int kernel_type, svm_type;
1210
1211    params = _params;
1212
1213    kernel_type = params.kernel_type;
1214    svm_type = params.svm_type;
1215
1216    if( kernel_type != LINEAR && kernel_type != POLY &&
1217        kernel_type != SIGMOID && kernel_type != RBF )
1218        CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );
1219
1220    if( kernel_type == LINEAR )
1221        params.gamma = 1;
1222    else if( params.gamma <= 0 )
1223        CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
1224
1225    if( kernel_type != SIGMOID && kernel_type != POLY )
1226        params.coef0 = 0;
1227    else if( params.coef0 < 0 )
1228        CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
1229
1230    if( kernel_type != POLY )
1231        params.degree = 0;
1232    else if( params.degree <= 0 )
1233        CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
1234
1235    if( svm_type != C_SVC && svm_type != NU_SVC &&
1236        svm_type != ONE_CLASS && svm_type != EPS_SVR &&
1237        svm_type != NU_SVR )
1238        CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );
1239
1240    if( svm_type == ONE_CLASS || svm_type == NU_SVC )
1241        params.C = 0;
1242    else if( params.C <= 0 )
1243        CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );
1244
1245    if( svm_type == C_SVC || svm_type == EPS_SVR )
1246        params.nu = 0;
1247    else if( params.nu <= 0 || params.nu >= 1 )
1248        CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
1249
1250    if( svm_type != EPS_SVR )
1251        params.p = 0;
1252    else if( params.p <= 0 )
1253        CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );
1254
1255    if( svm_type != C_SVC )
1256        params.class_weights = 0;
1257
1258    params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
1259    params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
1260    ok = true;
1261
1262    __END__;
1263
1264    return ok;
1265}
1266
1267
1268
1269void CvSVM::create_kernel()
1270{
1271    kernel = new CvSVMKernel(&params,0);
1272}
1273
1274
1275void CvSVM::create_solver( )
1276{
1277    solver = new CvSVMSolver;
1278}
1279
1280
1281// switching function
1282bool CvSVM::train1( int sample_count, int var_count, const float** samples,
1283                    const void* _responses, double Cp, double Cn,
1284                    CvMemStorage* _storage, double* alpha, double& rho )
1285{
1286    bool ok = false;
1287
1288    //CV_FUNCNAME( "CvSVM::train1" );
1289
1290    __BEGIN__;
1291
1292    CvSVMSolutionInfo si;
1293    int svm_type = params.svm_type;
1294
1295    si.rho = 0;
1296
1297    ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (schar*)_responses,
1298                                                  Cp, Cn, _storage, kernel, alpha, si ) :
1299         svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (schar*)_responses,
1300                                                    _storage, kernel, alpha, si ) :
1301         svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
1302                                                          _storage, kernel, alpha, si ) :
1303         svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
1304                                                      _storage, kernel, alpha, si ) :
1305         svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
1306                                                    _storage, kernel, alpha, si ) : false;
1307
1308    rho = si.rho;
1309
1310    __END__;
1311
1312    return ok;
1313}
1314
1315
1316bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
1317                    const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
1318{
1319    bool ok = false;
1320
1321    CV_FUNCNAME( "CvSVM::do_train" );
1322
1323    __BEGIN__;
1324
1325    CvSVMDecisionFunc* df = 0;
1326    const int sample_size = var_count*sizeof(samples[0][0]);
1327    int i, j, k;
1328
1329    if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
1330    {
1331        int sv_count = 0;
1332
1333        CV_CALL( decision_func = df =
1334            (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));
1335
1336        df->rho = 0;
1337        if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
1338            responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
1339            EXIT;
1340
1341        for( i = 0; i < sample_count; i++ )
1342            sv_count += fabs(alpha[i]) > 0;
1343
1344        sv_total = df->sv_count = sv_count;
1345        CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
1346        CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));
1347
1348        for( i = k = 0; i < sample_count; i++ )
1349        {
1350            if( fabs(alpha[i]) > 0 )
1351            {
1352                CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1353                memcpy( sv[k], samples[i], sample_size );
1354                df->alpha[k++] = alpha[i];
1355            }
1356        }
1357    }
1358    else
1359    {
1360        int class_count = class_labels->cols;
1361        int* sv_tab = 0;
1362        const float** temp_samples = 0;
1363        int* class_ranges = 0;
1364        schar* temp_y = 0;
1365        assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );
1366
1367        if( svm_type == CvSVM::C_SVC && params.class_weights )
1368        {
1369            const CvMat* cw = params.class_weights;
1370
1371            if( !CV_IS_MAT(cw) || cw->cols != 1 && cw->rows != 1 ||
1372                cw->rows + cw->cols - 1 != class_count ||
1373                CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1 )
1374                CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
1375                    "containing as many elements as the number of classes" );
1376
1377            CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
1378            CV_CALL( cvConvert( cw, class_weights ));
1379            CV_CALL( cvScale( class_weights, class_weights, params.C ));
1380        }
1381
1382        CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
1383            (class_count*(class_count-1)/2)*sizeof(df[0])));
1384
1385        CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
1386        memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
1387        CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
1388                            (class_count + 1)*sizeof(class_ranges[0])));
1389        CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
1390                            sample_count*sizeof(temp_samples[0])));
1391        CV_CALL( temp_y = (schar*)cvMemStorageAlloc( temp_storage, sample_count));
1392
1393        class_ranges[class_count] = 0;
1394        cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
1395        //check that while cross-validation there were the samples from all the classes
1396        if( class_ranges[class_count] <= 0 )
1397            CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
1398            "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );
1399
1400        if( svm_type == NU_SVC )
1401        {
1402            // check if nu is feasible
1403            for(i = 0; i < class_count; i++ )
1404            {
1405                int ci = class_ranges[i+1] - class_ranges[i];
1406                for( j = i+1; j< class_count; j++ )
1407                {
1408                    int cj = class_ranges[j+1] - class_ranges[j];
1409                    if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
1410                    {
1411                        // !!!TODO!!! add some diagnostic
1412                        EXIT; // exit immediately; will release the model and return NULL pointer
1413                    }
1414                }
1415            }
1416        }
1417
1418        // train n*(n-1)/2 classifiers
1419        for( i = 0; i < class_count; i++ )
1420        {
1421            for( j = i+1; j < class_count; j++, df++ )
1422            {
1423                int si = class_ranges[i], ci = class_ranges[i+1] - si;
1424                int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
1425                double Cp = params.C, Cn = Cp;
1426                int k1 = 0, sv_count = 0;
1427
1428                for( k = 0; k < ci; k++ )
1429                {
1430                    temp_samples[k] = samples[si + k];
1431                    temp_y[k] = 1;
1432                }
1433
1434                for( k = 0; k < cj; k++ )
1435                {
1436                    temp_samples[ci + k] = samples[sj + k];
1437                    temp_y[ci + k] = -1;
1438                }
1439
1440                if( class_weights )
1441                {
1442                    Cp = class_weights->data.db[i];
1443                    Cn = class_weights->data.db[j];
1444                }
1445
1446                if( !train1( ci + cj, var_count, temp_samples, temp_y,
1447                             Cp, Cn, temp_storage, alpha, df->rho ))
1448                    EXIT;
1449
1450                for( k = 0; k < ci + cj; k++ )
1451                    sv_count += fabs(alpha[k]) > 0;
1452
1453                df->sv_count = sv_count;
1454
1455                CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
1456                                                sv_count*sizeof(df->alpha[0])));
1457                CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
1458                                                sv_count*sizeof(df->sv_index[0])));
1459
1460                for( k = 0; k < ci; k++ )
1461                {
1462                    if( fabs(alpha[k]) > 0 )
1463                    {
1464                        sv_tab[si + k] = 1;
1465                        df->sv_index[k1] = si + k;
1466                        df->alpha[k1++] = alpha[k];
1467                    }
1468                }
1469
1470                for( k = 0; k < cj; k++ )
1471                {
1472                    if( fabs(alpha[ci + k]) > 0 )
1473                    {
1474                        sv_tab[sj + k] = 1;
1475                        df->sv_index[k1] = sj + k;
1476                        df->alpha[k1++] = alpha[ci + k];
1477                    }
1478                }
1479            }
1480        }
1481
1482        // allocate support vectors and initialize sv_tab
1483        for( i = 0, k = 0; i < sample_count; i++ )
1484        {
1485            if( sv_tab[i] )
1486                sv_tab[i] = ++k;
1487        }
1488
1489        sv_total = k;
1490        CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));
1491
1492        for( i = 0, k = 0; i < sample_count; i++ )
1493        {
1494            if( sv_tab[i] )
1495            {
1496                CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1497                memcpy( sv[k], samples[i], sample_size );
1498                k++;
1499            }
1500        }
1501
1502        df = (CvSVMDecisionFunc*)decision_func;
1503
1504        // set sv pointers
1505        for( i = 0; i < class_count; i++ )
1506        {
1507            for( j = i+1; j < class_count; j++, df++ )
1508            {
1509                for( k = 0; k < df->sv_count; k++ )
1510                {
1511                    df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
1512                    assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
1513                }
1514            }
1515        }
1516    }
1517
1518    ok = true;
1519
1520    __END__;
1521
1522    return ok;
1523}
1524
1525bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
1526    const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1527{
1528    bool ok = false;
1529    CvMat* responses = 0;
1530    CvMemStorage* temp_storage = 0;
1531    const float** samples = 0;
1532
1533    CV_FUNCNAME( "CvSVM::train" );
1534
1535    __BEGIN__;
1536
1537    int svm_type, sample_count, var_count, sample_size;
1538    int block_size = 1 << 16;
1539    double* alpha;
1540
1541    clear();
1542    CV_CALL( set_params( _params ));
1543
1544    svm_type = _params.svm_type;
1545
1546    /* Prepare training data and related parameters */
1547    CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
1548                                 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1549                                 svm_type == CvSVM::C_SVC ||
1550                                 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1551                                 CV_VAR_ORDERED, _var_idx, _sample_idx,
1552                                 false, &samples, &sample_count, &var_count, &var_all,
1553                                 &responses, &class_labels, &var_idx ));
1554
1555
1556    sample_size = var_count*sizeof(samples[0][0]);
1557
1558    // make the storage block size large enough to fit all
1559    // the temporary vectors and output support vectors.
1560    block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1561    block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1562    block_size = MAX( block_size, sample_size*2 + 1024 );
1563
1564    CV_CALL( storage = cvCreateMemStorage(block_size));
1565    CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
1566    CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1567
1568    create_kernel();
1569    create_solver();
1570
1571    if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
1572        EXIT;
1573
1574    ok = true; // model has been trained succesfully
1575
1576    __END__;
1577
1578    delete solver;
1579    solver = 0;
1580    cvReleaseMemStorage( &temp_storage );
1581    cvReleaseMat( &responses );
1582    cvFree( &samples );
1583
1584    if( cvGetErrStatus() < 0 || !ok )
1585        clear();
1586
1587    return ok;
1588}
1589
1590bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
1591    const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
1592    CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
1593    CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
1594{
1595    bool ok = false;
1596    CvMat* responses = 0;
1597    CvMat* responses_local = 0;
1598    CvMemStorage* temp_storage = 0;
1599    const float** samples = 0;
1600    const float** samples_local = 0;
1601
1602    CV_FUNCNAME( "CvSVM::train_auto" );
1603    __BEGIN__;
1604
1605    int svm_type, sample_count, var_count, sample_size;
1606    int block_size = 1 << 16;
1607    double* alpha;
1608    int i, k;
1609    CvRNG rng = cvRNG(-1);
1610
1611    // all steps are logarithmic and must be > 1
1612    double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
1613    double gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
1614    double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
1615    float min_error = FLT_MAX, error;
1616
1617    if( _params.svm_type == CvSVM::ONE_CLASS )
1618    {
1619        if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
1620            EXIT;
1621        return true;
1622    }
1623
1624    clear();
1625
1626    if( k_fold < 2 )
1627        CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
1628
1629    CV_CALL(set_params( _params ));
1630    svm_type = _params.svm_type;
1631
1632    // All the parameters except, possibly, <coef0> are positive.
1633    // <coef0> is nonnegative
1634    if( C_grid.step <= 1 )
1635    {
1636        C_grid.min_val = C_grid.max_val = params.C;
1637        C_grid.step = 10;
1638    }
1639    else
1640        CV_CALL(C_grid.check());
1641
1642    if( gamma_grid.step <= 1 )
1643    {
1644        gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1645        gamma_grid.step = 10;
1646    }
1647    else
1648        CV_CALL(gamma_grid.check());
1649
1650    if( p_grid.step <= 1 )
1651    {
1652        p_grid.min_val = p_grid.max_val = params.p;
1653        p_grid.step = 10;
1654    }
1655    else
1656        CV_CALL(p_grid.check());
1657
1658    if( nu_grid.step <= 1 )
1659    {
1660        nu_grid.min_val = nu_grid.max_val = params.nu;
1661        nu_grid.step = 10;
1662    }
1663    else
1664        CV_CALL(nu_grid.check());
1665
1666    if( coef_grid.step <= 1 )
1667    {
1668        coef_grid.min_val = coef_grid.max_val = params.coef0;
1669        coef_grid.step = 10;
1670    }
1671    else
1672        CV_CALL(coef_grid.check());
1673
1674    if( degree_grid.step <= 1 )
1675    {
1676        degree_grid.min_val = degree_grid.max_val = params.degree;
1677        degree_grid.step = 10;
1678    }
1679    else
1680        CV_CALL(degree_grid.check());
1681
1682    // these parameters are not used:
1683    if( params.kernel_type != CvSVM::POLY )
1684        degree_grid.min_val = degree_grid.max_val = params.degree;
1685    if( params.kernel_type == CvSVM::LINEAR )
1686        gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1687    if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
1688        coef_grid.min_val = coef_grid.max_val = params.coef0;
1689    if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
1690        C_grid.min_val = C_grid.max_val = params.C;
1691    if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
1692        nu_grid.min_val = nu_grid.max_val = params.nu;
1693    if( svm_type != CvSVM::EPS_SVR )
1694        p_grid.min_val = p_grid.max_val = params.p;
1695
1696    CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
1697    CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );
1698
1699    /* Prepare training data and related parameters */
1700    CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
1701                                 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1702                                 svm_type == CvSVM::C_SVC ||
1703                                 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1704                                 CV_VAR_ORDERED, _var_idx, _sample_idx,
1705                                 false, &samples, &sample_count, &var_count, &var_all,
1706                                 &responses, &class_labels, &var_idx ));
1707
1708    sample_size = var_count*sizeof(samples[0][0]);
1709
1710    // make the storage block size large enough to fit all
1711    // the temporary vectors and output support vectors.
1712    block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1713    block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1714    block_size = MAX( block_size, sample_size*2 + 1024 );
1715
1716    CV_CALL(storage = cvCreateMemStorage(block_size));
1717    CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
1718    CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1719
1720    create_kernel();
1721    create_solver();
1722
1723    {
1724    const int testset_size = sample_count/k_fold;
1725    const int trainset_size = sample_count - testset_size;
1726    const int last_testset_size = sample_count - testset_size*(k_fold-1);
1727    const int last_trainset_size = sample_count - last_testset_size;
1728    const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
1729
1730    size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
1731    size_t size = 2*last_trainset_size*sizeof(samples[0]);
1732
1733    samples_local = (const float**) cvAlloc( size );
1734    memset( samples_local, 0, size );
1735
1736    responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
1737    cvZero( responses_local );
1738
1739    // randomly permute samples and responses
1740    for( i = 0; i < sample_count; i++ )
1741    {
1742        int i1 = cvRandInt( &rng ) % sample_count;
1743        int i2 = cvRandInt( &rng ) % sample_count;
1744        const float* temp;
1745        float t;
1746        int y;
1747
1748        CV_SWAP( samples[i1], samples[i2], temp );
1749        if( is_regression )
1750            CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
1751        else
1752            CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
1753    }
1754
1755    C = C_grid.min_val;
1756    do
1757    {
1758      params.C = C;
1759      gamma = gamma_grid.min_val;
1760      do
1761      {
1762        params.gamma = gamma;
1763        p = p_grid.min_val;
1764        do
1765        {
1766          params.p = p;
1767          nu = nu_grid.min_val;
1768          do
1769          {
1770            params.nu = nu;
1771            coef = coef_grid.min_val;
1772            do
1773            {
1774              params.coef0 = coef;
1775              degree = degree_grid.min_val;
1776              do
1777              {
1778                params.degree = degree;
1779
1780                float** test_samples_ptr = (float**)samples;
1781                uchar* true_resp = responses->data.ptr;
1782                int test_size = testset_size;
1783                int train_size = trainset_size;
1784
1785                error = 0;
1786                for( k = 0; k < k_fold; k++ )
1787                {
1788                    memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
1789                    memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
1790                        sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1791
1792                    memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
1793                    memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
1794                        true_resp + resp_elem_size*test_size,
1795                        sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1796
1797                    if( k == k_fold - 1 )
1798                    {
1799                        test_size = last_testset_size;
1800                        train_size = last_trainset_size;
1801                        responses_local->cols = last_trainset_size;
1802                    }
1803
1804                    // Train SVM on <train_size> samples
1805                    if( !do_train( svm_type, train_size, var_count,
1806                        (const float**)samples_local, responses_local, temp_storage, alpha ) )
1807                        EXIT;
1808
1809                    // Compute test set error on <test_size> samples
1810                    CvMat s = cvMat( 1, var_count, CV_32FC1 );
1811                    for( i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
1812                    {
1813                        float resp;
1814                        s.data.fl = *test_samples_ptr;
1815                        resp = predict( &s );
1816                        error += is_regression ? powf( resp - *(float*)true_resp, 2 )
1817                            : ((int)resp != *(int*)true_resp);
1818                    }
1819                }
1820                if( min_error > error )
1821                {
1822                    min_error   = error;
1823                    best_degree = degree;
1824                    best_gamma  = gamma;
1825                    best_coef   = coef;
1826                    best_C      = C;
1827                    best_nu     = nu;
1828                    best_p      = p;
1829                }
1830                degree *= degree_grid.step;
1831              }
1832              while( degree < degree_grid.max_val );
1833              coef *= coef_grid.step;
1834            }
1835            while( coef < coef_grid.max_val );
1836            nu *= nu_grid.step;
1837          }
1838          while( nu < nu_grid.max_val );
1839          p *= p_grid.step;
1840        }
1841        while( p < p_grid.max_val );
1842        gamma *= gamma_grid.step;
1843      }
1844      while( gamma < gamma_grid.max_val );
1845      C *= C_grid.step;
1846    }
1847    while( C < C_grid.max_val );
1848    }
1849
1850    min_error /= (float) sample_count;
1851
1852    params.C      = best_C;
1853    params.nu     = best_nu;
1854    params.p      = best_p;
1855    params.gamma  = best_gamma;
1856    params.degree = best_degree;
1857    params.coef0  = best_coef;
1858
1859    CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));
1860
1861    __END__;
1862
1863    delete solver;
1864    solver = 0;
1865    cvReleaseMemStorage( &temp_storage );
1866    cvReleaseMat( &responses );
1867    cvReleaseMat( &responses_local );
1868    cvFree( &samples );
1869    cvFree( &samples_local );
1870
1871    if( cvGetErrStatus() < 0 || !ok )
1872        clear();
1873
1874    return ok;
1875}
1876
1877float CvSVM::predict( const CvMat* sample ) const
1878{
1879    bool local_alloc = 0;
1880    float result = 0;
1881    float* row_sample = 0;
1882    Qfloat* buffer = 0;
1883
1884    CV_FUNCNAME( "CvSVM::predict" );
1885
1886    __BEGIN__;
1887
1888    int class_count;
1889    int var_count, buf_sz;
1890
1891    if( !kernel )
1892        CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
1893
1894    class_count = class_labels ? class_labels->cols :
1895                  params.svm_type == ONE_CLASS ? 1 : 0;
1896
1897    CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
1898                                   class_count, 0, &row_sample ));
1899
1900    var_count = get_var_count();
1901
1902    buf_sz = sv_total*sizeof(buffer[0]) + (class_count+1)*sizeof(int);
1903    if( buf_sz <= CV_MAX_LOCAL_SIZE )
1904    {
1905        CV_CALL( buffer = (Qfloat*)cvStackAlloc( buf_sz ));
1906        local_alloc = 1;
1907    }
1908    else
1909        CV_CALL( buffer = (Qfloat*)cvAlloc( buf_sz ));
1910
1911    if( params.svm_type == EPS_SVR ||
1912        params.svm_type == NU_SVR ||
1913        params.svm_type == ONE_CLASS )
1914    {
1915        CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1916        int i, sv_count = df->sv_count;
1917        double sum = -df->rho;
1918
1919        kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
1920        for( i = 0; i < sv_count; i++ )
1921            sum += buffer[i]*df->alpha[i];
1922
1923        result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
1924    }
1925    else if( params.svm_type == C_SVC ||
1926             params.svm_type == NU_SVC )
1927    {
1928        CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1929        int* vote = (int*)(buffer + sv_total);
1930        int i, j, k;
1931
1932        memset( vote, 0, class_count*sizeof(vote[0]));
1933        kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
1934
1935        for( i = 0; i < class_count; i++ )
1936        {
1937            for( j = i+1; j < class_count; j++, df++ )
1938            {
1939                double sum = -df->rho;
1940                int sv_count = df->sv_count;
1941                for( k = 0; k < sv_count; k++ )
1942                    sum += df->alpha[k]*buffer[df->sv_index[k]];
1943
1944                vote[sum > 0 ? i : j]++;
1945            }
1946        }
1947
1948        for( i = 1, k = 0; i < class_count; i++ )
1949        {
1950            if( vote[i] > vote[k] )
1951                k = i;
1952        }
1953
1954        result = (float)(class_labels->data.i[k]);
1955    }
1956    else
1957        CV_ERROR( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
1958                                "the SVM structure is probably corrupted" );
1959
1960    __END__;
1961
1962    if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
1963        cvFree( &row_sample );
1964
1965    if( !local_alloc )
1966        cvFree( &buffer );
1967
1968    return result;
1969}
1970
1971
1972void CvSVM::write_params( CvFileStorage* fs )
1973{
1974    //CV_FUNCNAME( "CvSVM::write_params" );
1975
1976    __BEGIN__;
1977
1978    int svm_type = params.svm_type;
1979    int kernel_type = params.kernel_type;
1980
1981    const char* svm_type_str =
1982        svm_type == CvSVM::C_SVC ? "C_SVC" :
1983        svm_type == CvSVM::NU_SVC ? "NU_SVC" :
1984        svm_type == CvSVM::ONE_CLASS ? "ONE_CLASS" :
1985        svm_type == CvSVM::EPS_SVR ? "EPS_SVR" :
1986        svm_type == CvSVM::NU_SVR ? "NU_SVR" : 0;
1987    const char* kernel_type_str =
1988        kernel_type == CvSVM::LINEAR ? "LINEAR" :
1989        kernel_type == CvSVM::POLY ? "POLY" :
1990        kernel_type == CvSVM::RBF ? "RBF" :
1991        kernel_type == CvSVM::SIGMOID ? "SIGMOID" : 0;
1992
1993    if( svm_type_str )
1994        cvWriteString( fs, "svm_type", svm_type_str );
1995    else
1996        cvWriteInt( fs, "svm_type", svm_type );
1997
1998    // save kernel
1999    cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );
2000
2001    if( kernel_type_str )
2002        cvWriteString( fs, "type", kernel_type_str );
2003    else
2004        cvWriteInt( fs, "type", kernel_type );
2005
2006    if( kernel_type == CvSVM::POLY || !kernel_type_str )
2007        cvWriteReal( fs, "degree", params.degree );
2008
2009    if( kernel_type != CvSVM::LINEAR || !kernel_type_str )
2010        cvWriteReal( fs, "gamma", params.gamma );
2011
2012    if( kernel_type == CvSVM::POLY || kernel_type == CvSVM::SIGMOID || !kernel_type_str )
2013        cvWriteReal( fs, "coef0", params.coef0 );
2014
2015    cvEndWriteStruct(fs);
2016
2017    if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
2018        svm_type == CvSVM::NU_SVR || !svm_type_str )
2019        cvWriteReal( fs, "C", params.C );
2020
2021    if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
2022        svm_type == CvSVM::NU_SVR || !svm_type_str )
2023        cvWriteReal( fs, "nu", params.nu );
2024
2025    if( svm_type == CvSVM::EPS_SVR || !svm_type_str )
2026        cvWriteReal( fs, "p", params.p );
2027
2028    cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
2029    if( params.term_crit.type & CV_TERMCRIT_EPS )
2030        cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
2031    if( params.term_crit.type & CV_TERMCRIT_ITER )
2032        cvWriteInt( fs, "iterations", params.term_crit.max_iter );
2033    cvEndWriteStruct( fs );
2034
2035    __END__;
2036}
2037
2038
2039void CvSVM::write( CvFileStorage* fs, const char* name )
2040{
2041    CV_FUNCNAME( "CvSVM::write" );
2042
2043    __BEGIN__;
2044
2045    int i, var_count = get_var_count(), df_count, class_count;
2046    const CvSVMDecisionFunc* df = decision_func;
2047
2048    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
2049
2050    write_params( fs );
2051
2052    cvWriteInt( fs, "var_all", var_all );
2053    cvWriteInt( fs, "var_count", var_count );
2054
2055    class_count = class_labels ? class_labels->cols :
2056                  params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2057
2058    if( class_count )
2059    {
2060        cvWriteInt( fs, "class_count", class_count );
2061
2062        if( class_labels )
2063            cvWrite( fs, "class_labels", class_labels );
2064
2065        if( class_weights )
2066            cvWrite( fs, "class_weights", class_weights );
2067    }
2068
2069    if( var_idx )
2070        cvWrite( fs, "var_idx", var_idx );
2071
2072    // write the joint collection of support vectors
2073    cvWriteInt( fs, "sv_total", sv_total );
2074    cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
2075    for( i = 0; i < sv_total; i++ )
2076    {
2077        cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
2078        cvWriteRawData( fs, sv[i], var_count, "f" );
2079        cvEndWriteStruct( fs );
2080    }
2081
2082    cvEndWriteStruct( fs );
2083
2084    // write decision functions
2085    df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2086    df = decision_func;
2087
2088    cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
2089    for( i = 0; i < df_count; i++ )
2090    {
2091        int sv_count = df[i].sv_count;
2092        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2093        cvWriteInt( fs, "sv_count", sv_count );
2094        cvWriteReal( fs, "rho", df[i].rho );
2095        cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
2096        cvWriteRawData( fs, df[i].alpha, df[i].sv_count, "d" );
2097        cvEndWriteStruct( fs );
2098        if( class_count > 1 )
2099        {
2100            cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
2101            cvWriteRawData( fs, df[i].sv_index, df[i].sv_count, "i" );
2102            cvEndWriteStruct( fs );
2103        }
2104        else
2105            CV_ASSERT( sv_count == sv_total );
2106        cvEndWriteStruct( fs );
2107    }
2108    cvEndWriteStruct( fs );
2109    cvEndWriteStruct( fs );
2110
2111    __END__;
2112}
2113
2114
2115void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
2116{
2117    CV_FUNCNAME( "CvSVM::read_params" );
2118
2119    __BEGIN__;
2120
2121    int svm_type, kernel_type;
2122    CvSVMParams _params;
2123
2124    CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
2125    CvFileNode* kernel_node;
2126    if( !tmp_node )
2127        CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );
2128
2129    if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2130        svm_type = cvReadInt( tmp_node, -1 );
2131    else
2132    {
2133        const char* svm_type_str = cvReadString( tmp_node, "" );
2134        svm_type =
2135            strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
2136            strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
2137            strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
2138            strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
2139            strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;
2140
2141        if( svm_type < 0 )
2142            CV_ERROR( CV_StsParseError, "Missing of invalid SVM type" );
2143    }
2144
2145    kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
2146    if( !kernel_node )
2147        CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );
2148
2149    tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
2150    if( !tmp_node )
2151        CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );
2152
2153    if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2154        kernel_type = cvReadInt( tmp_node, -1 );
2155    else
2156    {
2157        const char* kernel_type_str = cvReadString( tmp_node, "" );
2158        kernel_type =
2159            strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
2160            strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
2161            strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
2162            strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;
2163
2164        if( kernel_type < 0 )
2165            CV_ERROR( CV_StsParseError, "Missing of invalid SVM kernel type" );
2166    }
2167
2168    _params.svm_type = svm_type;
2169    _params.kernel_type = kernel_type;
2170    _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
2171    _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
2172    _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );
2173
2174    _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
2175    _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
2176    _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
2177    _params.class_weights = 0;
2178
2179    tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
2180    if( tmp_node )
2181    {
2182        _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
2183        _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
2184        _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
2185                               (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
2186    }
2187    else
2188        _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );
2189
2190    set_params( _params );
2191
2192    __END__;
2193}
2194
2195
2196void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
2197{
2198    const double not_found_dbl = DBL_MAX;
2199
2200    CV_FUNCNAME( "CvSVM::read" );
2201
2202    __BEGIN__;
2203
2204    int i, var_count, df_count, class_count;
2205    int block_size = 1 << 16, sv_size;
2206    CvFileNode *sv_node, *df_node;
2207    CvSVMDecisionFunc* df;
2208    CvSeqReader reader;
2209
2210    if( !svm_node )
2211        CV_ERROR( CV_StsParseError, "The requested element is not found" );
2212
2213    clear();
2214
2215    // read SVM parameters
2216    read_params( fs, svm_node );
2217
2218    // and top-level data
2219    sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
2220    var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
2221    var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
2222    class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
2223
2224    if( sv_total <= 0 || var_all <= 0 || var_count <= 0 || var_count > var_all || class_count < 0 )
2225        CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
2226
2227    CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
2228    CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
2229    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "comp_idx" ));
2230
2231    if( class_count > 1 && (!class_labels ||
2232        !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
2233        CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );
2234
2235    if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
2236        CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );
2237
2238    // read support vectors
2239    sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
2240    if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
2241        CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );
2242
2243    block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
2244    block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
2245    block_size = MAX( block_size, var_all*(int)sizeof(double));
2246    CV_CALL( storage = cvCreateMemStorage( block_size ));
2247    CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
2248                                sv_total*sizeof(sv[0]) ));
2249
2250    CV_CALL( cvStartReadSeq( sv_node->data.seq, &reader, 0 ));
2251    sv_size = var_count*sizeof(sv[0][0]);
2252
2253    for( i = 0; i < sv_total; i++ )
2254    {
2255        CvFileNode* sv_elem = (CvFileNode*)reader.ptr;
2256        CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
2257                   sv_elem->data.seq->total == var_count) );
2258
2259        CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
2260        CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
2261        CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, reader );
2262    }
2263
2264    // read decision functions
2265    df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2266    df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
2267    if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
2268        df_node->data.seq->total != df_count )
2269        CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
2270                  "or has a wrong number of elements" );
2271
2272    CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
2273    cvStartReadSeq( df_node->data.seq, &reader, 0 );
2274
2275    for( i = 0; i < df_count; i++ )
2276    {
2277        CvFileNode* df_elem = (CvFileNode*)reader.ptr;
2278        CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );
2279
2280        int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
2281        if( sv_count <= 0 )
2282            CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
2283        df[i].sv_count = sv_count;
2284
2285        df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
2286        if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
2287            CV_ERROR( CV_StsParseError, "rho is missing" );
2288
2289        if( !alpha_node )
2290            CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );
2291
2292        CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
2293                                        sv_count*sizeof(df[i].alpha[0])));
2294        CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(alpha_node->tag) &&
2295                   alpha_node->data.seq->total == sv_count );
2296        CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));
2297
2298        if( class_count > 1 )
2299        {
2300            CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
2301            if( !index_node )
2302                CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
2303            CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
2304                                            sv_count*sizeof(df[i].sv_index[0])));
2305            CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(index_node->tag) &&
2306                   index_node->data.seq->total == sv_count );
2307            CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));
2308        }
2309        else
2310            df[i].sv_index = 0;
2311
2312        CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
2313    }
2314
2315    create_kernel();
2316
2317    __END__;
2318}
2319
2320#if 0
2321
2322static void*
2323icvCloneSVM( const void* _src )
2324{
2325    CvSVMModel* dst = 0;
2326
2327    CV_FUNCNAME( "icvCloneSVM" );
2328
2329    __BEGIN__;
2330
2331    const CvSVMModel* src = (const CvSVMModel*)_src;
2332    int var_count, class_count;
2333    int i, sv_total, df_count;
2334    int sv_size;
2335
2336    if( !CV_IS_SVM(src) )
2337        CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
2338
2339    // 0. create initial CvSVMModel structure
2340    CV_CALL( dst = icvCreateSVM() );
2341    dst->params = src->params;
2342    dst->params.weight_labels = 0;
2343    dst->params.weights = 0;
2344
2345    dst->var_all = src->var_all;
2346    if( src->class_labels )
2347        dst->class_labels = cvCloneMat( src->class_labels );
2348    if( src->class_weights )
2349        dst->class_weights = cvCloneMat( src->class_weights );
2350    if( src->comp_idx )
2351        dst->comp_idx = cvCloneMat( src->comp_idx );
2352
2353    var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
2354    class_count = src->class_labels ? src->class_labels->cols :
2355                  src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2356    sv_total = dst->sv_total = src->sv_total;
2357    CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
2358    CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
2359                                    sv_total*sizeof(dst->sv[0]) ));
2360
2361    sv_size = var_count*sizeof(dst->sv[0][0]);
2362
2363    for( i = 0; i < sv_total; i++ )
2364    {
2365        CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
2366        memcpy( dst->sv[i], src->sv[i], sv_size );
2367    }
2368
2369    df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2370
2371    CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
2372
2373    for( i = 0; i < df_count; i++ )
2374    {
2375        const CvSVMDecisionFunc *sdf =
2376            (const CvSVMDecisionFunc*)src->decision_func+i;
2377        CvSVMDecisionFunc *ddf =
2378            (CvSVMDecisionFunc*)dst->decision_func+i;
2379        int sv_count = sdf->sv_count;
2380        ddf->sv_count = sv_count;
2381        ddf->rho = sdf->rho;
2382        CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
2383                                        sv_count*sizeof(ddf->alpha[0])));
2384        memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));
2385
2386        if( class_count > 1 )
2387        {
2388            CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
2389                                                sv_count*sizeof(ddf->sv_index[0])));
2390            memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
2391        }
2392        else
2393            ddf->sv_index = 0;
2394    }
2395
2396    __END__;
2397
2398    if( cvGetErrStatus() < 0 && dst )
2399        icvReleaseSVM( &dst );
2400
2401    return dst;
2402}
2403
2404static int icvRegisterSVMType()
2405{
2406    CvTypeInfo info;
2407    memset( &info, 0, sizeof(info) );
2408
2409    info.flags = 0;
2410    info.header_size = sizeof( info );
2411    info.is_instance = icvIsSVM;
2412    info.release = (CvReleaseFunc)icvReleaseSVM;
2413    info.read = icvReadSVM;
2414    info.write = icvWriteSVM;
2415    info.clone = icvCloneSVM;
2416    info.type_name = CV_TYPE_NAME_ML_SVM;
2417    cvRegisterType( &info );
2418
2419    return 1;
2420}
2421
2422
2423static int svm = icvRegisterSVMType();
2424
2425/* The function trains SVM model with optimal parameters, obtained by using cross-validation.
2426The parameters to be estimated should be indicated by setting theirs values to FLT_MAX.
2427The optimal parameters are saved in <model_params> */
2428CV_IMPL CvStatModel*
2429cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
2430            const CvMat* responses,
2431            CvStatModelParams* model_params,
2432            const CvStatModelParams* cross_valid_params,
2433            const CvMat* comp_idx,
2434            const CvMat* sample_idx,
2435            const CvParamGrid* degree_grid,
2436            const CvParamGrid* gamma_grid,
2437            const CvParamGrid* coef_grid,
2438            const CvParamGrid* C_grid,
2439            const CvParamGrid* nu_grid,
2440            const CvParamGrid* p_grid )
2441{
2442    CvStatModel* svm = 0;
2443
2444    CV_FUNCNAME("cvTainSVMCrossValidation");
2445    __BEGIN__;
2446
2447    double degree_step = 7,
2448	       g_step      = 15,
2449		   coef_step   = 14,
2450		   C_step      = 20,
2451		   nu_step     = 5,
2452		   p_step      = 7; // all steps must be > 1
2453    double degree_begin = 0.01, degree_end = 2;
2454    double g_begin      = 1e-5, g_end      = 0.5;
2455    double coef_begin   = 0.1,  coef_end   = 300;
2456    double C_begin      = 0.1,  C_end      = 6000;
2457    double nu_begin     = 0.01,  nu_end    = 0.4;
2458    double p_begin      = 0.01, p_end      = 100;
2459
2460    double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
2461
2462	double best_rate    = 0;
2463    double best_degree  = degree_begin;
2464    double best_gamma   = g_begin;
2465    double best_coef    = coef_begin;
2466	double best_C       = C_begin;
2467	double best_nu      = nu_begin;
2468    double best_p       = p_begin;
2469
2470    CvSVMModelParams svm_params, *psvm_params;
2471    CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
2472    int svm_type, kernel;
2473    int is_regression;
2474
2475    if( !model_params )
2476        CV_ERROR( CV_StsBadArg, "" );
2477    if( !cv_params )
2478        CV_ERROR( CV_StsBadArg, "" );
2479
2480    svm_params = *(CvSVMModelParams*)model_params;
2481    psvm_params = (CvSVMModelParams*)model_params;
2482    svm_type = svm_params.svm_type;
2483    kernel = svm_params.kernel_type;
2484
2485    svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
2486    svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
2487    svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
2488    svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
2489    svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
2490    svm_params.p = svm_params.p > 0 ? svm_params.p : 1;
2491
2492    if( degree_grid )
2493    {
2494        if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
2495              degree_grid->step == 0) )
2496        {
2497            if( degree_grid->min_val > degree_grid->max_val )
2498                CV_ERROR( CV_StsBadArg,
2499                "low bound of grid should be less then the upper one");
2500            if( degree_grid->step <= 1 )
2501                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2502            degree_begin = degree_grid->min_val;
2503            degree_end   = degree_grid->max_val;
2504            degree_step  = degree_grid->step;
2505        }
2506    }
2507    else
2508        degree_begin = degree_end = svm_params.degree;
2509
2510    if( gamma_grid )
2511    {
2512        if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
2513              gamma_grid->step == 0) )
2514        {
2515            if( gamma_grid->min_val > gamma_grid->max_val )
2516                CV_ERROR( CV_StsBadArg,
2517                "low bound of grid should be less then the upper one");
2518            if( gamma_grid->step <= 1 )
2519                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2520            g_begin = gamma_grid->min_val;
2521            g_end   = gamma_grid->max_val;
2522            g_step  = gamma_grid->step;
2523        }
2524    }
2525    else
2526        g_begin = g_end = svm_params.gamma;
2527
2528    if( coef_grid )
2529    {
2530        if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
2531              coef_grid->step == 0) )
2532        {
2533            if( coef_grid->min_val > coef_grid->max_val )
2534                CV_ERROR( CV_StsBadArg,
2535                "low bound of grid should be less then the upper one");
2536            if( coef_grid->step <= 1 )
2537                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2538            coef_begin = coef_grid->min_val;
2539            coef_end   = coef_grid->max_val;
2540            coef_step  = coef_grid->step;
2541        }
2542    }
2543    else
2544        coef_begin = coef_end = svm_params.coef0;
2545
2546    if( C_grid )
2547    {
2548        if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
2549        {
2550            if( C_grid->min_val > C_grid->max_val )
2551                CV_ERROR( CV_StsBadArg,
2552                "low bound of grid should be less then the upper one");
2553            if( C_grid->step <= 1 )
2554                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2555            C_begin = C_grid->min_val;
2556            C_end   = C_grid->max_val;
2557            C_step  = C_grid->step;
2558        }
2559    }
2560    else
2561        C_begin = C_end = svm_params.C;
2562
2563    if( nu_grid )
2564    {
2565        if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
2566        {
2567            if( nu_grid->min_val > nu_grid->max_val )
2568                CV_ERROR( CV_StsBadArg,
2569                "low bound of grid should be less then the upper one");
2570            if( nu_grid->step <= 1 )
2571                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2572            nu_begin = nu_grid->min_val;
2573            nu_end   = nu_grid->max_val;
2574            nu_step  = nu_grid->step;
2575        }
2576    }
2577    else
2578        nu_begin = nu_end = svm_params.nu;
2579
2580    if( p_grid )
2581    {
2582        if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
2583        {
2584            if( p_grid->min_val > p_grid->max_val )
2585                CV_ERROR( CV_StsBadArg,
2586                "low bound of grid should be less then the upper one");
2587            if( p_grid->step <= 1 )
2588                CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2589            p_begin = p_grid->min_val;
2590            p_end   = p_grid->max_val;
2591            p_step  = p_grid->step;
2592        }
2593    }
2594    else
2595        p_begin = p_end = svm_params.p;
2596
2597    // these parameters are not used:
2598    if( kernel != CvSVM::POLY )
2599        degree_begin = degree_end = svm_params.degree;
2600
2601   if( kernel == CvSVM::LINEAR )
2602        g_begin = g_end = svm_params.gamma;
2603
2604    if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
2605        coef_begin = coef_end = svm_params.coef0;
2606
2607    if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
2608        C_begin = C_end = svm_params.C;
2609
2610    if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
2611        nu_begin = nu_end = svm_params.nu;
2612
2613    if( svm_type != CvSVM::EPS_SVR )
2614        p_begin = p_end = svm_params.p;
2615
2616    is_regression = cv_params->is_regression;
2617    best_rate = is_regression ? FLT_MAX : 0;
2618
2619    assert( g_step > 1 && degree_step > 1 && coef_step > 1);
2620    assert( p_step > 1 && C_step > 1 && nu_step > 1 );
2621
2622    for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
2623    {
2624      svm_params.degree = degree;
2625      //printf("degree = %.3f\n", degree );
2626      for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
2627      {
2628        svm_params.gamma = gamma;
2629        //printf("   gamma = %.3f\n", gamma );
2630        for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
2631        {
2632          svm_params.coef0 = coef;
2633          //printf("      coef = %.3f\n", coef );
2634          for( C = C_begin; C <= C_end; C *= C_step )
2635          {
2636            svm_params.C = C;
2637            //printf("         C = %.3f\n", C );
2638            for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
2639            {
2640              svm_params.nu = nu;
2641              //printf("            nu = %.3f\n", nu );
2642              for( p = p_begin; p <= p_end; p *= p_step )
2643              {
2644                int well;
2645                svm_params.p = p;
2646                //printf("               p = %.3f\n", p );
2647
2648                CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
2649                    cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));
2650
2651                well =  rate > best_rate && !is_regression || rate < best_rate && is_regression;
2652                if( well || (rate == best_rate && C < best_C) )
2653                {
2654                    best_rate   = rate;
2655                    best_degree = degree;
2656                    best_gamma  = gamma;
2657                    best_coef   = coef;
2658                    best_C      = C;
2659                    best_nu     = nu;
2660                    best_p      = p;
2661                }
2662                //printf("                  rate = %.2f\n", rate );
2663              }
2664            }
2665          }
2666        }
2667      }
2668    }
2669    //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
2670      //  best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );
2671
2672    psvm_params->C      = best_C;
2673    psvm_params->nu     = best_nu;
2674    psvm_params->p      = best_p;
2675    psvm_params->gamma  = best_gamma;
2676    psvm_params->degree = best_degree;
2677    psvm_params->coef0  = best_coef;
2678
2679    CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));
2680
2681    __END__;
2682
2683    return svm;
2684}
2685
2686#endif
2687
2688/* End of file. */
2689
2690