1/*M///////////////////////////////////////////////////////////////////////////////////////
2//
3//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4//
5//  By downloading, copying, installing or using the software you agree to this license.
6//  If you do not agree to this license, do not download, install,
7//  copy or use the software.
8//
9//
10//                        Intel License Agreement
11//                For Open Source Computer Vision Library
12//
13// Copyright (C) 2000, Intel Corporation, all rights reserved.
14// Third party copyrights are property of their respective owners.
15//
16// Redistribution and use in source and binary forms, with or without modification,
17// are permitted provided that the following conditions are met:
18//
19//   * Redistribution's of source code must retain the above copyright notice,
20//     this list of conditions and the following disclaimer.
21//
22//   * Redistribution's in binary form must reproduce the above copyright notice,
23//     this list of conditions and the following disclaimer in the documentation
24//     and/or other materials provided with the distribution.
25//
26//   * The name of Intel Corporation may not be used to endorse or promote products
27//     derived from this software without specific prior written permission.
28//
29// This software is provided by the copyright holders and contributors "as is" and
30// any express or implied warranties, including, but not limited to, the implied
31// warranties of merchantability and fitness for a particular purpose are disclaimed.
32// In no event shall the Intel Corporation or contributors be liable for any direct,
33// indirect, incidental, special, exemplary, or consequential damages
34// (including, but not limited to, procurement of substitute goods or services;
35// loss of use, data, or profits; or business interruption) however caused
36// and on any theory of liability, whether in contract, strict liability,
37// or tort (including negligence or otherwise) arising in any way out of
38// the use of this software, even if advised of the possibility of such damage.
39//
40//M*/
41
42/*
43    Partially based on Yossi Rubner code:
44    =========================================================================
45    emd.c
46
47    Last update: 3/14/98
48
49    An implementation of the Earth Movers Distance.
50    Based of the solution for the Transportation problem as described in
51    "Introduction to Mathematical Programming" by F. S. Hillier and
52    G. J. Lieberman, McGraw-Hill, 1990.
53
54    Copyright (C) 1998 Yossi Rubner
55    Computer Science Department, Stanford University
56    E-Mail: rubner@cs.stanford.edu   URL: http://vision.stanford.edu/~rubner
57    ==========================================================================
58*/
59#include "_cv.h"
60
/* Limit on the number of simplex pivots attempted by cvCalcEMD2. */
#define MAX_ITERATIONS  500
/* Sentinel used as "infinite" cost / reduced cost. */
#define CV_EMD_INF      ((float)1.0e+20)
/* Relative tolerance; scaled by max_cost or by the total supply before use. */
#define CV_EMD_EPS      ((float)1.0e-05)
64
/* Singly linked node holding one float; arrays of these (state->u, state->v)
   are threaded into lists while the row/column potentials are computed. */
typedef struct CvNode1D CvNode1D;
struct CvNode1D
{
    float     val;    /* stored value */
    CvNode1D *next;   /* following node, or NULL at the end of a list */
};
72
/* One cell (i, j) of the sparse basis.  The same node is linked both into the
   list of its row (next[0]) and into the list of its column (next[1]). */
typedef struct CvNode2D CvNode2D;
struct CvNode2D
{
    float     val;      /* flow assigned to cell (i, j) */
    CvNode2D *next[2];  /* [0]: next node in the same row, [1]: next node in the same column */
    int       i;        /* row index */
    int       j;        /* column index */
};
81
82
83typedef struct CvEMDState
84{
85    int ssize, dsize;
86
87    float **cost;
88    CvNode2D *_x;
89    CvNode2D *end_x;
90    CvNode2D *enter_x;
91    char **is_x;
92
93    CvNode2D **rows_x;
94    CvNode2D **cols_x;
95
96    CvNode1D *u;
97    CvNode1D *v;
98
99    int* idx1;
100    int* idx2;
101
102    /* find_loop buffers */
103    CvNode2D **loop;
104    char *is_used;
105
106    /* russel buffers */
107    float *s;
108    float *d;
109    float **delta;
110
111    float weight, max_cost;
112    char *buffer;
113}
114CvEMDState;
115
116/* static function declaration */
117static CvStatus icvInitEMD( const float *signature1, int size1,
118                            const float *signature2, int size2,
119                            int dims, CvDistanceFunction dist_func, void *user_param,
120                            const float* cost, int cost_step,
121                            CvEMDState * state, float *lower_bound,
122                            char *local_buffer, int local_buffer_size );
123
124static CvStatus icvFindBasicVariables( float **cost, char **is_x,
125                                       CvNode1D * u, CvNode1D * v, int ssize, int dsize );
126
127static float icvIsOptimal( float **cost, char **is_x,
128                           CvNode1D * u, CvNode1D * v,
129                           int ssize, int dsize, CvNode2D * enter_x );
130
131static void icvRussel( CvEMDState * state );
132
133
134static CvStatus icvNewSolution( CvEMDState * state );
135static int icvFindLoop( CvEMDState * state );
136
137static void icvAddBasicVariable( CvEMDState * state,
138                                 int min_i, int min_j,
139                                 CvNode1D * prev_u_min_i,
140                                 CvNode1D * prev_v_min_j,
141                                 CvNode1D * u_head );
142
143static float icvDistL2( const float *x, const float *y, void *user_param );
144static float icvDistL1( const float *x, const float *y, void *user_param );
145static float icvDistC( const float *x, const float *y, void *user_param );
146
147/* The main function */
148CV_IMPL float
149cvCalcEMD2( const CvArr* signature_arr1,
150            const CvArr* signature_arr2,
151            int dist_type,
152            CvDistanceFunction dist_func,
153            const CvArr* cost_matrix,
154            CvArr* flow_matrix,
155            float *lower_bound,
156            void *user_param )
157{
158    char local_buffer[16384];
159    char *local_buffer_ptr = (char *)cvAlignPtr(local_buffer,16);
160    CvEMDState state;
161    float emd = 0;
162
163    CV_FUNCNAME( "cvCalcEMD2" );
164
165    memset( &state, 0, sizeof(state));
166
167    __BEGIN__;
168
169    double total_cost = 0;
170    CvStatus result = CV_NO_ERR;
171    float eps, min_delta;
172    CvNode2D *xp = 0;
173    CvMat sign_stub1, *signature1 = (CvMat*)signature_arr1;
174    CvMat sign_stub2, *signature2 = (CvMat*)signature_arr2;
175    CvMat cost_stub, *cost = &cost_stub;
176    CvMat flow_stub, *flow = (CvMat*)flow_matrix;
177    int dims, size1, size2;
178
179    CV_CALL( signature1 = cvGetMat( signature1, &sign_stub1 ));
180    CV_CALL( signature2 = cvGetMat( signature2, &sign_stub2 ));
181
182    if( signature1->cols != signature2->cols )
183        CV_ERROR( CV_StsUnmatchedSizes, "The arrays must have equal number of columns (which is number of dimensions but 1)" );
184
185    dims = signature1->cols - 1;
186    size1 = signature1->rows;
187    size2 = signature2->rows;
188
189    if( !CV_ARE_TYPES_EQ( signature1, signature2 ))
190        CV_ERROR( CV_StsUnmatchedFormats, "The array must have equal types" );
191
192    if( CV_MAT_TYPE( signature1->type ) != CV_32FC1 )
193        CV_ERROR( CV_StsUnsupportedFormat, "The signatures must be 32fC1" );
194
195    if( flow )
196    {
197        CV_CALL( flow = cvGetMat( flow, &flow_stub ));
198
199        if( flow->rows != size1 || flow->cols != size2 )
200            CV_ERROR( CV_StsUnmatchedSizes,
201            "The flow matrix size does not match to the signatures' sizes" );
202
203        if( CV_MAT_TYPE( flow->type ) != CV_32FC1 )
204            CV_ERROR( CV_StsUnsupportedFormat, "The flow matrix must be 32fC1" );
205    }
206
207    cost->data.fl = 0;
208    cost->step = 0;
209
210    if( dist_type < 0 )
211    {
212        if( cost_matrix )
213        {
214            if( dist_func )
215                CV_ERROR( CV_StsBadArg,
216                "Only one of cost matrix or distance function should be non-NULL in case of user-defined distance" );
217
218            if( lower_bound )
219                CV_ERROR( CV_StsBadArg,
220                "The lower boundary can not be calculated if the cost matrix is used" );
221
222            CV_CALL( cost = cvGetMat( cost_matrix, &cost_stub ));
223            if( cost->rows != size1 || cost->cols != size2 )
224                CV_ERROR( CV_StsUnmatchedSizes,
225                "The cost matrix size does not match to the signatures' sizes" );
226
227            if( CV_MAT_TYPE( cost->type ) != CV_32FC1 )
228                CV_ERROR( CV_StsUnsupportedFormat, "The cost matrix must be 32fC1" );
229        }
230        else if( !dist_func )
231            CV_ERROR( CV_StsNullPtr, "In case of user-defined distance Distance function is undefined" );
232    }
233    else
234    {
235        if( dims == 0 )
236            CV_ERROR( CV_StsBadSize,
237            "Number of dimensions can be 0 only if a user-defined metric is used" );
238        user_param = (void *) (size_t)dims;
239        switch (dist_type)
240        {
241        case CV_DIST_L1:
242            dist_func = icvDistL1;
243            break;
244        case CV_DIST_L2:
245            dist_func = icvDistL2;
246            break;
247        case CV_DIST_C:
248            dist_func = icvDistC;
249            break;
250        default:
251            CV_ERROR( CV_StsBadFlag, "Bad or unsupported metric type" );
252        }
253    }
254
255    IPPI_CALL( result = icvInitEMD( signature1->data.fl, size1,
256                                    signature2->data.fl, size2,
257                                    dims, dist_func, user_param,
258                                    cost->data.fl, cost->step,
259                                    &state, lower_bound, local_buffer_ptr,
260                                    sizeof( local_buffer ) - 16 ));
261
262    if( result > 0 && lower_bound )
263    {
264        emd = *lower_bound;
265        EXIT;
266    }
267
268    eps = CV_EMD_EPS * state.max_cost;
269
270    /* if ssize = 1 or dsize = 1 then we are done, else ... */
271    if( state.ssize > 1 && state.dsize > 1 )
272    {
273        int itr;
274
275        for( itr = 1; itr < MAX_ITERATIONS; itr++ )
276        {
277            /* find basic variables */
278            result = icvFindBasicVariables( state.cost, state.is_x,
279                                            state.u, state.v, state.ssize, state.dsize );
280            if( result < 0 )
281                break;
282
283            /* check for optimality */
284            min_delta = icvIsOptimal( state.cost, state.is_x,
285                                      state.u, state.v,
286                                      state.ssize, state.dsize, state.enter_x );
287
288            if( min_delta == CV_EMD_INF )
289            {
290                CV_ERROR( CV_StsNoConv, "" );
291            }
292
293            /* if no negative deltamin, we found the optimal solution */
294            if( min_delta >= -eps )
295                break;
296
297            /* improve solution */
298            IPPI_CALL( icvNewSolution( &state ));
299        }
300    }
301
302    /* compute the total flow */
303    for( xp = state._x; xp < state.end_x; xp++ )
304    {
305        float val = xp->val;
306        int i = xp->i;
307        int j = xp->j;
308        int ci = state.idx1[i];
309        int cj = state.idx2[j];
310
311        if( xp != state.enter_x && ci >= 0 && cj >= 0 )
312        {
313            total_cost += (double)val * state.cost[i][j];
314            if( flow )
315                ((float*)(flow->data.ptr + flow->step*ci))[cj] = val;
316        }
317    }
318
319    emd = (float) (total_cost / state.weight);
320
321    __END__;
322
323    if( state.buffer && state.buffer != local_buffer_ptr )
324        cvFree( &state.buffer );
325
326    return emd;
327}
328
329
330/************************************************************************************\
331*          initialize structure, allocate buffers and generate initial golution      *
332\************************************************************************************/
333static CvStatus
334icvInitEMD( const float* signature1, int size1,
335            const float* signature2, int size2,
336            int dims, CvDistanceFunction dist_func, void* user_param,
337            const float* cost, int cost_step,
338            CvEMDState* state, float* lower_bound,
339            char* local_buffer, int local_buffer_size )
340{
341    float s_sum = 0, d_sum = 0, diff;
342    int i, j;
343    int ssize = 0, dsize = 0;
344    int equal_sums = 1;
345    int buffer_size;
346    float max_cost = 0;
347    char *buffer, *buffer_end;
348
349    memset( state, 0, sizeof( *state ));
350    assert( cost_step % sizeof(float) == 0 );
351    cost_step /= sizeof(float);
352
353    /* calculate buffer size */
354    buffer_size = (size1+1) * (size2+1) * (sizeof( float ) +    /* cost */
355                                   sizeof( char ) +     /* is_x */
356                                   sizeof( float )) +   /* delta matrix */
357        (size1 + size2 + 2) * (sizeof( CvNode2D ) + /* _x */
358                           sizeof( CvNode2D * ) +  /* cols_x & rows_x */
359                           sizeof( CvNode1D ) + /* u & v */
360                           sizeof( float ) + /* s & d */
361                           sizeof( int ) + sizeof(CvNode2D*)) +  /* idx1 & idx2 */
362        (size1+1) * (sizeof( float * ) + sizeof( char * ) + /* rows pointers for */
363                 sizeof( float * )) + 256;      /*  cost, is_x and delta */
364
365    if( buffer_size < (int) (dims * 2 * sizeof( float )))
366    {
367        buffer_size = dims * 2 * sizeof( float );
368    }
369
370    /* allocate buffers */
371    if( local_buffer != 0 && local_buffer_size >= buffer_size )
372    {
373        buffer = local_buffer;
374    }
375    else
376    {
377        buffer = (char*)cvAlloc( buffer_size );
378        if( !buffer )
379            return CV_OUTOFMEM_ERR;
380    }
381
382    state->buffer = buffer;
383    buffer_end = buffer + buffer_size;
384
385    state->idx1 = (int*) buffer;
386    buffer += (size1 + 1) * sizeof( int );
387
388    state->idx2 = (int*) buffer;
389    buffer += (size2 + 1) * sizeof( int );
390
391    state->s = (float *) buffer;
392    buffer += (size1 + 1) * sizeof( float );
393
394    state->d = (float *) buffer;
395    buffer += (size2 + 1) * sizeof( float );
396
397    /* sum up the supply and demand */
398    for( i = 0; i < size1; i++ )
399    {
400        float weight = signature1[i * (dims + 1)];
401
402        if( weight > 0 )
403        {
404            s_sum += weight;
405            state->s[ssize] = weight;
406            state->idx1[ssize++] = i;
407
408        }
409        else if( weight < 0 )
410            return CV_BADRANGE_ERR;
411    }
412
413    for( i = 0; i < size2; i++ )
414    {
415        float weight = signature2[i * (dims + 1)];
416
417        if( weight > 0 )
418        {
419            d_sum += weight;
420            state->d[dsize] = weight;
421            state->idx2[dsize++] = i;
422        }
423        else if( weight < 0 )
424            return CV_BADRANGE_ERR;
425    }
426
427    if( ssize == 0 || dsize == 0 )
428        return CV_BADRANGE_ERR;
429
430    /* if supply different than the demand, add a zero-cost dummy cluster */
431    diff = s_sum - d_sum;
432    if( fabs( diff ) >= CV_EMD_EPS * s_sum )
433    {
434        equal_sums = 0;
435        if( diff < 0 )
436        {
437            state->s[ssize] = -diff;
438            state->idx1[ssize++] = -1;
439        }
440        else
441        {
442            state->d[dsize] = diff;
443            state->idx2[dsize++] = -1;
444        }
445    }
446
447    state->ssize = ssize;
448    state->dsize = dsize;
449    state->weight = s_sum > d_sum ? s_sum : d_sum;
450
451    if( lower_bound && equal_sums )     /* check lower bound */
452    {
453        int sz1 = size1 * (dims + 1), sz2 = size2 * (dims + 1);
454        float lb = 0;
455
456        float* xs = (float *) buffer;
457        float* xd = xs + dims;
458
459        memset( xs, 0, dims*sizeof(xs[0]));
460        memset( xd, 0, dims*sizeof(xd[0]));
461
462        for( j = 0; j < sz1; j += dims + 1 )
463        {
464            float weight = signature1[j];
465            for( i = 0; i < dims; i++ )
466                xs[i] += signature1[j + i + 1] * weight;
467        }
468
469        for( j = 0; j < sz2; j += dims + 1 )
470        {
471            float weight = signature2[j];
472            for( i = 0; i < dims; i++ )
473                xd[i] += signature2[j + i + 1] * weight;
474        }
475
476        lb = dist_func( xs, xd, user_param ) / state->weight;
477        i = *lower_bound <= lb;
478        *lower_bound = lb;
479        if( i )
480            return ( CvStatus ) 1;
481    }
482
483    /* assign pointers */
484    state->is_used = (char *) buffer;
485    /* init delta matrix */
486    state->delta = (float **) buffer;
487    buffer += ssize * sizeof( float * );
488
489    for( i = 0; i < ssize; i++ )
490    {
491        state->delta[i] = (float *) buffer;
492        buffer += dsize * sizeof( float );
493    }
494
495    state->loop = (CvNode2D **) buffer;
496    buffer += (ssize + dsize + 1) * sizeof(CvNode2D*);
497
498    state->_x = state->end_x = (CvNode2D *) buffer;
499    buffer += (ssize + dsize) * sizeof( CvNode2D );
500
501    /* init cost matrix */
502    state->cost = (float **) buffer;
503    buffer += ssize * sizeof( float * );
504
505    /* compute the distance matrix */
506    for( i = 0; i < ssize; i++ )
507    {
508        int ci = state->idx1[i];
509
510        state->cost[i] = (float *) buffer;
511        buffer += dsize * sizeof( float );
512
513        if( ci >= 0 )
514        {
515            for( j = 0; j < dsize; j++ )
516            {
517                int cj = state->idx2[j];
518                if( cj < 0 )
519                    state->cost[i][j] = 0;
520                else
521                {
522                    float val;
523                    if( dist_func )
524                    {
525                        val = dist_func( signature1 + ci * (dims + 1) + 1,
526                                         signature2 + cj * (dims + 1) + 1,
527                                         user_param );
528                    }
529                    else
530                    {
531                        assert( cost );
532                        val = cost[cost_step*ci + cj];
533                    }
534                    state->cost[i][j] = val;
535                    if( max_cost < val )
536                        max_cost = val;
537                }
538            }
539        }
540        else
541        {
542            for( j = 0; j < dsize; j++ )
543                state->cost[i][j] = 0;
544        }
545    }
546
547    state->max_cost = max_cost;
548
549    memset( buffer, 0, buffer_end - buffer );
550
551    state->rows_x = (CvNode2D **) buffer;
552    buffer += ssize * sizeof( CvNode2D * );
553
554    state->cols_x = (CvNode2D **) buffer;
555    buffer += dsize * sizeof( CvNode2D * );
556
557    state->u = (CvNode1D *) buffer;
558    buffer += ssize * sizeof( CvNode1D );
559
560    state->v = (CvNode1D *) buffer;
561    buffer += dsize * sizeof( CvNode1D );
562
563    /* init is_x matrix */
564    state->is_x = (char **) buffer;
565    buffer += ssize * sizeof( char * );
566
567    for( i = 0; i < ssize; i++ )
568    {
569        state->is_x[i] = buffer;
570        buffer += dsize;
571    }
572
573    assert( buffer <= buffer_end );
574
575    icvRussel( state );
576
577    state->enter_x = (state->end_x)++;
578    return CV_NO_ERR;
579}
580
581
582/****************************************************************************************\
583*                              icvFindBasicVariables                                   *
584\****************************************************************************************/
585static CvStatus
586icvFindBasicVariables( float **cost, char **is_x,
587                       CvNode1D * u, CvNode1D * v, int ssize, int dsize )
588{
589    int i, j, found;
590    int u_cfound, v_cfound;
591    CvNode1D u0_head, u1_head, *cur_u, *prev_u;
592    CvNode1D v0_head, v1_head, *cur_v, *prev_v;
593
594    /* initialize the rows list (u) and the columns list (v) */
595    u0_head.next = u;
596    for( i = 0; i < ssize; i++ )
597    {
598        u[i].next = u + i + 1;
599    }
600    u[ssize - 1].next = 0;
601    u1_head.next = 0;
602
603    v0_head.next = ssize > 1 ? v + 1 : 0;
604    for( i = 1; i < dsize; i++ )
605    {
606        v[i].next = v + i + 1;
607    }
608    v[dsize - 1].next = 0;
609    v1_head.next = 0;
610
611    /* there are ssize+dsize variables but only ssize+dsize-1 independent equations,
612       so set v[0]=0 */
613    v[0].val = 0;
614    v1_head.next = v;
615    v1_head.next->next = 0;
616
617    /* loop until all variables are found */
618    u_cfound = v_cfound = 0;
619    while( u_cfound < ssize || v_cfound < dsize )
620    {
621        found = 0;
622        if( v_cfound < dsize )
623        {
624            /* loop over all marked columns */
625            prev_v = &v1_head;
626
627            for( found |= (cur_v = v1_head.next) != 0; cur_v != 0; cur_v = cur_v->next )
628            {
629                float cur_v_val = cur_v->val;
630
631                j = (int)(cur_v - v);
632                /* find the variables in column j */
633                prev_u = &u0_head;
634                for( cur_u = u0_head.next; cur_u != 0; )
635                {
636                    i = (int)(cur_u - u);
637                    if( is_x[i][j] )
638                    {
639                        /* compute u[i] */
640                        cur_u->val = cost[i][j] - cur_v_val;
641                        /* ...and add it to the marked list */
642                        prev_u->next = cur_u->next;
643                        cur_u->next = u1_head.next;
644                        u1_head.next = cur_u;
645                        cur_u = prev_u->next;
646                    }
647                    else
648                    {
649                        prev_u = cur_u;
650                        cur_u = cur_u->next;
651                    }
652                }
653                prev_v->next = cur_v->next;
654                v_cfound++;
655            }
656        }
657
658        if( u_cfound < ssize )
659        {
660            /* loop over all marked rows */
661            prev_u = &u1_head;
662            for( found |= (cur_u = u1_head.next) != 0; cur_u != 0; cur_u = cur_u->next )
663            {
664                float cur_u_val = cur_u->val;
665                float *_cost;
666                char *_is_x;
667
668                i = (int)(cur_u - u);
669                _cost = cost[i];
670                _is_x = is_x[i];
671                /* find the variables in rows i */
672                prev_v = &v0_head;
673                for( cur_v = v0_head.next; cur_v != 0; )
674                {
675                    j = (int)(cur_v - v);
676                    if( _is_x[j] )
677                    {
678                        /* compute v[j] */
679                        cur_v->val = _cost[j] - cur_u_val;
680                        /* ...and add it to the marked list */
681                        prev_v->next = cur_v->next;
682                        cur_v->next = v1_head.next;
683                        v1_head.next = cur_v;
684                        cur_v = prev_v->next;
685                    }
686                    else
687                    {
688                        prev_v = cur_v;
689                        cur_v = cur_v->next;
690                    }
691                }
692                prev_u->next = cur_u->next;
693                u_cfound++;
694            }
695        }
696
697        if( !found )
698        {
699            return CV_NOTDEFINED_ERR;
700        }
701    }
702
703    return CV_NO_ERR;
704}
705
706
707/****************************************************************************************\
708*                                   icvIsOptimal                                       *
709\****************************************************************************************/
710static float
711icvIsOptimal( float **cost, char **is_x,
712              CvNode1D * u, CvNode1D * v, int ssize, int dsize, CvNode2D * enter_x )
713{
714    float delta, min_delta = CV_EMD_INF;
715    int i, j, min_i = 0, min_j = 0;
716
717    /* find the minimal cij-ui-vj over all i,j */
718    for( i = 0; i < ssize; i++ )
719    {
720        float u_val = u[i].val;
721        float *_cost = cost[i];
722        char *_is_x = is_x[i];
723
724        for( j = 0; j < dsize; j++ )
725        {
726            if( !_is_x[j] )
727            {
728                delta = _cost[j] - u_val - v[j].val;
729                if( min_delta > delta )
730                {
731                    min_delta = delta;
732                    min_i = i;
733                    min_j = j;
734                }
735            }
736        }
737    }
738
739    enter_x->i = min_i;
740    enter_x->j = min_j;
741
742    return min_delta;
743}
744
745/****************************************************************************************\
746*                                   icvNewSolution                                     *
747\****************************************************************************************/
748static CvStatus
749icvNewSolution( CvEMDState * state )
750{
751    int i, j;
752    float min_val = CV_EMD_INF;
753    int steps;
754    CvNode2D head, *cur_x, *next_x, *leave_x = 0;
755    CvNode2D *enter_x = state->enter_x;
756    CvNode2D **loop = state->loop;
757
758    /* enter the new basic variable */
759    i = enter_x->i;
760    j = enter_x->j;
761    state->is_x[i][j] = 1;
762    enter_x->next[0] = state->rows_x[i];
763    enter_x->next[1] = state->cols_x[j];
764    enter_x->val = 0;
765    state->rows_x[i] = enter_x;
766    state->cols_x[j] = enter_x;
767
768    /* find a chain reaction */
769    steps = icvFindLoop( state );
770
771    if( steps == 0 )
772        return CV_NOTDEFINED_ERR;
773
774    /* find the largest value in the loop */
775    for( i = 1; i < steps; i += 2 )
776    {
777        float temp = loop[i]->val;
778
779        if( min_val > temp )
780        {
781            leave_x = loop[i];
782            min_val = temp;
783        }
784    }
785
786    /* update the loop */
787    for( i = 0; i < steps; i += 2 )
788    {
789        float temp0 = loop[i]->val + min_val;
790        float temp1 = loop[i + 1]->val - min_val;
791
792        loop[i]->val = temp0;
793        loop[i + 1]->val = temp1;
794    }
795
796    /* remove the leaving basic variable */
797    i = leave_x->i;
798    j = leave_x->j;
799    state->is_x[i][j] = 0;
800
801    head.next[0] = state->rows_x[i];
802    cur_x = &head;
803    while( (next_x = cur_x->next[0]) != leave_x )
804    {
805        cur_x = next_x;
806        assert( cur_x );
807    }
808    cur_x->next[0] = next_x->next[0];
809    state->rows_x[i] = head.next[0];
810
811    head.next[1] = state->cols_x[j];
812    cur_x = &head;
813    while( (next_x = cur_x->next[1]) != leave_x )
814    {
815        cur_x = next_x;
816        assert( cur_x );
817    }
818    cur_x->next[1] = next_x->next[1];
819    state->cols_x[j] = head.next[1];
820
821    /* set enter_x to be the new empty slot */
822    state->enter_x = leave_x;
823
824    return CV_NO_ERR;
825}
826
827
828
829/****************************************************************************************\
830*                                    icvFindLoop                                       *
831\****************************************************************************************/
832static int
833icvFindLoop( CvEMDState * state )
834{
835    int i, steps = 1;
836    CvNode2D *new_x;
837    CvNode2D **loop = state->loop;
838    CvNode2D *enter_x = state->enter_x, *_x = state->_x;
839    char *is_used = state->is_used;
840
841    memset( is_used, 0, state->ssize + state->dsize );
842
843    new_x = loop[0] = enter_x;
844    is_used[enter_x - _x] = 1;
845    steps = 1;
846
847    do
848    {
849        if( (steps & 1) == 1 )
850        {
851            /* find an unused x in the row */
852            new_x = state->rows_x[new_x->i];
853            while( new_x != 0 && is_used[new_x - _x] )
854                new_x = new_x->next[0];
855        }
856        else
857        {
858            /* find an unused x in the column, or the entering x */
859            new_x = state->cols_x[new_x->j];
860            while( new_x != 0 && is_used[new_x - _x] && new_x != enter_x )
861                new_x = new_x->next[1];
862            if( new_x == enter_x )
863                break;
864        }
865
866        if( new_x != 0 )        /* found the next x */
867        {
868            /* add x to the loop */
869            loop[steps++] = new_x;
870            is_used[new_x - _x] = 1;
871        }
872        else                    /* didn't find the next x */
873        {
874            /* backtrack */
875            do
876            {
877                i = steps & 1;
878                new_x = loop[steps - 1];
879                do
880                {
881                    new_x = new_x->next[i];
882                }
883                while( new_x != 0 && is_used[new_x - _x] );
884
885                if( new_x == 0 )
886                {
887                    is_used[loop[--steps] - _x] = 0;
888                }
889            }
890            while( new_x == 0 && steps > 0 );
891
892            is_used[loop[steps - 1] - _x] = 0;
893            loop[steps - 1] = new_x;
894            is_used[new_x - _x] = 1;
895        }
896    }
897    while( steps > 0 );
898
899    return steps;
900}
901
902
903
904/****************************************************************************************\
905*                                        icvRussel                                     *
906\****************************************************************************************/
907static void
908icvRussel( CvEMDState * state )
909{
910    int i, j, min_i = -1, min_j = -1;
911    float min_delta, diff;
912    CvNode1D u_head, *cur_u, *prev_u;
913    CvNode1D v_head, *cur_v, *prev_v;
914    CvNode1D *prev_u_min_i = 0, *prev_v_min_j = 0, *remember;
915    CvNode1D *u = state->u, *v = state->v;
916    int ssize = state->ssize, dsize = state->dsize;
917    float eps = CV_EMD_EPS * state->max_cost;
918    float **cost = state->cost;
919    float **delta = state->delta;
920
921    /* initialize the rows list (ur), and the columns list (vr) */
922    u_head.next = u;
923    for( i = 0; i < ssize; i++ )
924    {
925        u[i].next = u + i + 1;
926    }
927    u[ssize - 1].next = 0;
928
929    v_head.next = v;
930    for( i = 0; i < dsize; i++ )
931    {
932        v[i].val = -CV_EMD_INF;
933        v[i].next = v + i + 1;
934    }
935    v[dsize - 1].next = 0;
936
937    /* find the maximum row and column values (ur[i] and vr[j]) */
938    for( i = 0; i < ssize; i++ )
939    {
940        float u_val = -CV_EMD_INF;
941        float *cost_row = cost[i];
942
943        for( j = 0; j < dsize; j++ )
944        {
945            float temp = cost_row[j];
946
947            if( u_val < temp )
948                u_val = temp;
949            if( v[j].val < temp )
950                v[j].val = temp;
951        }
952        u[i].val = u_val;
953    }
954
955    /* compute the delta matrix */
956    for( i = 0; i < ssize; i++ )
957    {
958        float u_val = u[i].val;
959        float *delta_row = delta[i];
960        float *cost_row = cost[i];
961
962        for( j = 0; j < dsize; j++ )
963        {
964            delta_row[j] = cost_row[j] - u_val - v[j].val;
965        }
966    }
967
968    /* find the basic variables */
969    do
970    {
971        /* find the smallest delta[i][j] */
972        min_i = -1;
973        min_delta = CV_EMD_INF;
974        prev_u = &u_head;
975        for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
976        {
977            i = (int)(cur_u - u);
978            float *delta_row = delta[i];
979
980            prev_v = &v_head;
981            for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
982            {
983                j = (int)(cur_v - v);
984                if( min_delta > delta_row[j] )
985                {
986                    min_delta = delta_row[j];
987                    min_i = i;
988                    min_j = j;
989                    prev_u_min_i = prev_u;
990                    prev_v_min_j = prev_v;
991                }
992                prev_v = cur_v;
993            }
994            prev_u = cur_u;
995        }
996
997        if( min_i < 0 )
998            break;
999
1000        /* add x[min_i][min_j] to the basis, and adjust supplies and cost */
1001        remember = prev_u_min_i->next;
1002        icvAddBasicVariable( state, min_i, min_j, prev_u_min_i, prev_v_min_j, &u_head );
1003
1004        /* update the necessary delta[][] */
1005        if( remember == prev_u_min_i->next )    /* line min_i was deleted */
1006        {
1007            for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1008            {
1009                j = (int)(cur_v - v);
1010                if( cur_v->val == cost[min_i][j] )      /* column j needs updating */
1011                {
1012                    float max_val = -CV_EMD_INF;
1013
1014                    /* find the new maximum value in the column */
1015                    for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1016                    {
1017                        float temp = cost[cur_u - u][j];
1018
1019                        if( max_val < temp )
1020                            max_val = temp;
1021                    }
1022
1023                    /* if needed, adjust the relevant delta[*][j] */
1024                    diff = max_val - cur_v->val;
1025                    cur_v->val = max_val;
1026                    if( fabs( diff ) < eps )
1027                    {
1028                        for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1029                            delta[cur_u - u][j] += diff;
1030                    }
1031                }
1032            }
1033        }
1034        else                    /* column min_j was deleted */
1035        {
1036            for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1037            {
1038                i = (int)(cur_u - u);
1039                if( cur_u->val == cost[i][min_j] )      /* row i needs updating */
1040                {
1041                    float max_val = -CV_EMD_INF;
1042
1043                    /* find the new maximum value in the row */
1044                    for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1045                    {
1046                        float temp = cost[i][cur_v - v];
1047
1048                        if( max_val < temp )
1049                            max_val = temp;
1050                    }
1051
1052                    /* if needed, adjust the relevant delta[i][*] */
1053                    diff = max_val - cur_u->val;
1054                    cur_u->val = max_val;
1055
1056                    if( fabs( diff ) < eps )
1057                    {
1058                        for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1059                            delta[i][cur_v - v] += diff;
1060                    }
1061                }
1062            }
1063        }
1064    }
1065    while( u_head.next != 0 || v_head.next != 0 );
1066}
1067
1068
1069
1070/****************************************************************************************\
1071*                                   icvAddBasicVariable                                *
1072\****************************************************************************************/
1073static void
1074icvAddBasicVariable( CvEMDState * state,
1075                     int min_i, int min_j,
1076                     CvNode1D * prev_u_min_i, CvNode1D * prev_v_min_j, CvNode1D * u_head )
1077{
1078    float temp;
1079    CvNode2D *end_x = state->end_x;
1080
1081    if( state->s[min_i] < state->d[min_j] + state->weight * CV_EMD_EPS )
1082    {                           /* supply exhausted */
1083        temp = state->s[min_i];
1084        state->s[min_i] = 0;
1085        state->d[min_j] -= temp;
1086    }
1087    else                        /* demand exhausted */
1088    {
1089        temp = state->d[min_j];
1090        state->d[min_j] = 0;
1091        state->s[min_i] -= temp;
1092    }
1093
1094    /* x(min_i,min_j) is a basic variable */
1095    state->is_x[min_i][min_j] = 1;
1096
1097    end_x->val = temp;
1098    end_x->i = min_i;
1099    end_x->j = min_j;
1100    end_x->next[0] = state->rows_x[min_i];
1101    end_x->next[1] = state->cols_x[min_j];
1102    state->rows_x[min_i] = end_x;
1103    state->cols_x[min_j] = end_x;
1104    state->end_x = end_x + 1;
1105
1106    /* delete supply row only if the empty, and if not last row */
1107    if( state->s[min_i] == 0 && u_head->next->next != 0 )
1108        prev_u_min_i->next = prev_u_min_i->next->next;  /* remove row from list */
1109    else
1110        prev_v_min_j->next = prev_v_min_j->next->next;  /* remove column from list */
1111}
1112
1113
1114/****************************************************************************************\
1115*                                  standard  metrics                                     *
1116\****************************************************************************************/
static float
icvDistL1( const float *x, const float *y, void *user_param )
{
    /* Manhattan (L1) distance between two dims-long vectors; the dimension
       count is smuggled through user_param as an integer. */
    const int dims = (int)(size_t)user_param;
    double sum = 0;
    int k = 0;

    while( k < dims )
    {
        sum += fabs( (double)(x[k] - y[k]) );
        k++;
    }
    return (float)sum;
}
1131
static float
icvDistL2( const float *x, const float *y, void *user_param )
{
    /* Euclidean (L2) distance between two dims-long vectors; the dimension
       count is passed through user_param.  Squares are accumulated in
       double, and the sum is narrowed to float before the square root. */
    const int dims = (int)(size_t)user_param;
    double sum_sq = 0;
    int k;

    for( k = 0; k < dims; ++k )
    {
        const double gap = (double)(x[k] - y[k]);
        sum_sq += gap * gap;
    }
    return cvSqrt( (float)sum_sq );
}
1146
static float
icvDistC( const float *x, const float *y, void *user_param )
{
    /* Chebyshev (L-infinity) distance: the largest absolute per-coordinate
       difference over dims coordinates (dims is passed through user_param).
       A NaN difference never replaces the running maximum. */
    const int dims = (int)(size_t)user_param;
    double largest = 0;
    int k;

    for( k = 0; k < dims; ++k )
    {
        double gap = fabs( x[k] - y[k] );
        largest = gap > largest ? gap : largest;
    }
    return (float)largest;
}
1162
1163/* End of file. */
1164
1165