1/***********************************************************************
2 * Software License Agreement (BSD License)
3 *
4 * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6 *
7 * THE BSD LICENSE
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * 1. Redistributions of source code must retain the above copyright
14 *    notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 *    notice, this list of conditions and the following disclaimer in the
17 *    documentation and/or other materials provided with the distribution.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/
30
31#ifndef OPENCV_FLANN_KDTREE_INDEX_H_
32#define OPENCV_FLANN_KDTREE_INDEX_H_
33
34#include <algorithm>
35#include <map>
36#include <cassert>
37#include <cstring>
38
39#include "general.h"
40#include "nn_index.h"
41#include "dynamic_bitset.h"
42#include "matrix.h"
43#include "result_set.h"
44#include "heap.h"
45#include "allocator.h"
46#include "random.h"
47#include "saving.h"
48
49
50namespace cvflann
51{
52
53struct KDTreeIndexParams : public IndexParams
54{
55    KDTreeIndexParams(int trees = 4)
56    {
57        (*this)["algorithm"] = FLANN_INDEX_KDTREE;
58        (*this)["trees"] = trees;
59    }
60};
61
62
63/**
64 * Randomized kd-tree index
65 *
66 * Contains the k-d trees and other information for indexing a set of points
67 * for nearest-neighbor matching.
68 */
69template <typename Distance>
70class KDTreeIndex : public NNIndex<Distance>
71{
72public:
73    typedef typename Distance::ElementType ElementType;
74    typedef typename Distance::ResultType DistanceType;
75
76
77    /**
78     * KDTree constructor
79     *
80     * Params:
81     *          inputData = dataset with the input features
82     *          params = parameters passed to the kdtree algorithm
83     */
84    KDTreeIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeIndexParams(),
85                Distance d = Distance() ) :
86        dataset_(inputData), index_params_(params), distance_(d)
87    {
88        size_ = dataset_.rows;
89        veclen_ = dataset_.cols;
90
91        trees_ = get_param(index_params_,"trees",4);
92        tree_roots_ = new NodePtr[trees_];
93
94        // Create a permutable array of indices to the input vectors.
95        vind_.resize(size_);
96        for (size_t i = 0; i < size_; ++i) {
97            vind_[i] = int(i);
98        }
99
100        mean_ = new DistanceType[veclen_];
101        var_ = new DistanceType[veclen_];
102    }
103
104
    /*
     * Copy construction and copy assignment are declared only; no definition
     * is visible in this header.
     * NOTE(review): presumably intentional, to keep an object that owns raw
     * arrays (tree_roots_, mean_, var_) from being copied -- confirm.
     */
    KDTreeIndex(const KDTreeIndex&);
    KDTreeIndex& operator=(const KDTreeIndex&);
107
108    /**
109     * Standard destructor
110     */
111    ~KDTreeIndex()
112    {
113        if (tree_roots_!=NULL) {
114            delete[] tree_roots_;
115        }
116        delete[] mean_;
117        delete[] var_;
118    }
119
120    /**
121     * Builds the index
122     */
123    void buildIndex()
124    {
125        /* Construct the randomized trees. */
126        for (int i = 0; i < trees_; i++) {
127            /* Randomize the order of vectors to allow for unbiased sampling. */
128            std::random_shuffle(vind_.begin(), vind_.end());
129            tree_roots_[i] = divideTree(&vind_[0], int(size_) );
130        }
131    }
132
133
    /**
     * Returns the algorithm identifier of this index, FLANN_INDEX_KDTREE.
     */
    flann_algorithm_t getType() const
    {
        return FLANN_INDEX_KDTREE;
    }
138
139
140    void saveIndex(FILE* stream)
141    {
142        save_value(stream, trees_);
143        for (int i=0; i<trees_; ++i) {
144            save_tree(stream, tree_roots_[i]);
145        }
146    }
147
148
149
150    void loadIndex(FILE* stream)
151    {
152        load_value(stream, trees_);
153        if (tree_roots_!=NULL) {
154            delete[] tree_roots_;
155        }
156        tree_roots_ = new NodePtr[trees_];
157        for (int i=0; i<trees_; ++i) {
158            load_tree(stream,tree_roots_[i]);
159        }
160
161        index_params_["algorithm"] = getType();
162        index_params_["trees"] = tree_roots_;
163    }
164
    /**
     *  Returns the number of vectors in the index, i.e. the number of rows
     *  of the dataset passed to the constructor.
     */
    size_t size() const
    {
        return size_;
    }
172
    /**
     * Returns the length of an index feature, i.e. the number of columns
     * of the dataset passed to the constructor.
     */
    size_t veclen() const
    {
        return veclen_;
    }
180
181    /**
182     * Computes the inde memory usage
183     * Returns: memory used by the index
184     */
185    int usedMemory() const
186    {
187        return int(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*sizeof(int));  // pool memory and vind array memory
188    }
189
190    /**
191     * Find set of nearest neighbors to vec. Their indices are stored inside
192     * the result object.
193     *
194     * Params:
195     *     result = the result object in which the indices of the nearest-neighbors are stored
196     *     vec = the vector for which to search the nearest neighbors
197     *     maxCheck = the maximum number of restarts (in a best-bin-first manner)
198     */
199    void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
200    {
201        int maxChecks = get_param(searchParams,"checks", 32);
202        float epsError = 1+get_param(searchParams,"eps",0.0f);
203
204        if (maxChecks==FLANN_CHECKS_UNLIMITED) {
205            getExactNeighbors(result, vec, epsError);
206        }
207        else {
208            getNeighbors(result, vec, maxChecks, epsError);
209        }
210    }
211
    /**
     * Returns a copy of the parameters held by this index (as given to the
     * constructor, and updated by loadIndex()).
     */
    IndexParams getParameters() const
    {
        return index_params_;
    }
216
217private:
218
219
220    /*--------------------- Internal Data Structures --------------------------*/
221    struct Node
222    {
223        /**
224         * Dimension used for subdivision.
225         */
226        int divfeat;
227        /**
228         * The values used for subdivision.
229         */
230        DistanceType divval;
231        /**
232         * The child nodes.
233         */
234        Node* child1, * child2;
235    };
236    typedef Node* NodePtr;
237    typedef BranchStruct<NodePtr, DistanceType> BranchSt;
238    typedef BranchSt* Branch;
239
240
241
242    void save_tree(FILE* stream, NodePtr tree)
243    {
244        save_value(stream, *tree);
245        if (tree->child1!=NULL) {
246            save_tree(stream, tree->child1);
247        }
248        if (tree->child2!=NULL) {
249            save_tree(stream, tree->child2);
250        }
251    }
252
253
254    void load_tree(FILE* stream, NodePtr& tree)
255    {
256        tree = pool_.allocate<Node>();
257        load_value(stream, *tree);
258        if (tree->child1!=NULL) {
259            load_tree(stream, tree->child1);
260        }
261        if (tree->child2!=NULL) {
262            load_tree(stream, tree->child2);
263        }
264    }
265
266
267    /**
268     * Create a tree node that subdivides the list of vecs from vind[first]
269     * to vind[last].  The routine is called recursively on each sublist.
270     * Place a pointer to this new tree node in the location pTree.
271     *
272     * Params: pTree = the new node to create
273     *                  first = index of the first vector
274     *                  last = index of the last vector
275     */
276    NodePtr divideTree(int* ind, int count)
277    {
278        NodePtr node = pool_.allocate<Node>(); // allocate memory
279
280        /* If too few exemplars remain, then make this a leaf node. */
281        if ( count == 1) {
282            node->child1 = node->child2 = NULL;    /* Mark as leaf node. */
283            node->divfeat = *ind;    /* Store index of this vec. */
284        }
285        else {
286            int idx;
287            int cutfeat;
288            DistanceType cutval;
289            meanSplit(ind, count, idx, cutfeat, cutval);
290
291            node->divfeat = cutfeat;
292            node->divval = cutval;
293            node->child1 = divideTree(ind, idx);
294            node->child2 = divideTree(ind+idx, count-idx);
295        }
296
297        return node;
298    }
299
300
301    /**
302     * Choose which feature to use in order to subdivide this set of vectors.
303     * Make a random choice among those with the highest variance, and use
304     * its variance as the threshold value.
305     */
306    void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
307    {
308        memset(mean_,0,veclen_*sizeof(DistanceType));
309        memset(var_,0,veclen_*sizeof(DistanceType));
310
311        /* Compute mean values.  Only the first SAMPLE_MEAN values need to be
312            sampled to get a good estimate.
313         */
314        int cnt = std::min((int)SAMPLE_MEAN+1, count);
315        for (int j = 0; j < cnt; ++j) {
316            ElementType* v = dataset_[ind[j]];
317            for (size_t k=0; k<veclen_; ++k) {
318                mean_[k] += v[k];
319            }
320        }
321        for (size_t k=0; k<veclen_; ++k) {
322            mean_[k] /= cnt;
323        }
324
325        /* Compute variances (no need to divide by count). */
326        for (int j = 0; j < cnt; ++j) {
327            ElementType* v = dataset_[ind[j]];
328            for (size_t k=0; k<veclen_; ++k) {
329                DistanceType dist = v[k] - mean_[k];
330                var_[k] += dist * dist;
331            }
332        }
333        /* Select one of the highest variance indices at random. */
334        cutfeat = selectDivision(var_);
335        cutval = mean_[cutfeat];
336
337        int lim1, lim2;
338        planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
339
340        if (lim1>count/2) index = lim1;
341        else if (lim2<count/2) index = lim2;
342        else index = count/2;
343
344        /* If either list is empty, it means that all remaining features
345         * are identical. Split in the middle to maintain a balanced tree.
346         */
347        if ((lim1==count)||(lim2==0)) index = count/2;
348    }
349
350
351    /**
352     * Select the top RAND_DIM largest values from v and return the index of
353     * one of these selected at random.
354     */
355    int selectDivision(DistanceType* v)
356    {
357        int num = 0;
358        size_t topind[RAND_DIM];
359
360        /* Create a list of the indices of the top RAND_DIM values. */
361        for (size_t i = 0; i < veclen_; ++i) {
362            if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
363                /* Put this element at end of topind. */
364                if (num < RAND_DIM) {
365                    topind[num++] = i;            /* Add to list. */
366                }
367                else {
368                    topind[num-1] = i;         /* Replace last element. */
369                }
370                /* Bubble end value down to right location by repeated swapping. */
371                int j = num - 1;
372                while (j > 0  &&  v[topind[j]] > v[topind[j-1]]) {
373                    std::swap(topind[j], topind[j-1]);
374                    --j;
375                }
376            }
377        }
378        /* Select a random integer in range [0,num-1], and return that index. */
379        int rnd = rand_int(num);
380        return (int)topind[rnd];
381    }
382
383
384    /**
385     *  Subdivide the list of points by a plane perpendicular on axe corresponding
386     *  to the 'cutfeat' dimension at 'cutval' position.
387     *
388     *  On return:
389     *  dataset[ind[0..lim1-1]][cutfeat]<cutval
390     *  dataset[ind[lim1..lim2-1]][cutfeat]==cutval
391     *  dataset[ind[lim2..count]][cutfeat]>cutval
392     */
393    void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
394    {
395        /* Move vector indices for left subtree to front of list. */
396        int left = 0;
397        int right = count-1;
398        for (;; ) {
399            while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
400            while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
401            if (left>right) break;
402            std::swap(ind[left], ind[right]); ++left; --right;
403        }
404        lim1 = left;
405        right = count-1;
406        for (;; ) {
407            while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
408            while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
409            if (left>right) break;
410            std::swap(ind[left], ind[right]); ++left; --right;
411        }
412        lim2 = left;
413    }
414
415    /**
416     * Performs an exact nearest neighbor search. The exact search performs a full
417     * traversal of the tree.
418     */
419    void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError)
420    {
421        //		checkID -= 1;  /* Set a different unique ID for each search. */
422
423        if (trees_ > 1) {
424            fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
425        }
426        if (trees_>0) {
427            searchLevelExact(result, vec, tree_roots_[0], 0.0, epsError);
428        }
429        assert(result.full());
430    }
431
432    /**
433     * Performs the approximate nearest-neighbor search. The search is approximate
434     * because the tree traversal is abandoned after a given number of descends in
435     * the tree.
436     */
437    void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError)
438    {
439        int i;
440        BranchSt branch;
441
442        int checkCount = 0;
443        Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
444        DynamicBitset checked(size_);
445
446        /* Search once through each tree down to root. */
447        for (i = 0; i < trees_; ++i) {
448            searchLevel(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
449        }
450
451        /* Keep searching other branches from heap until finished. */
452        while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
453            searchLevel(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
454        }
455
456        delete heap;
457
458        assert(result.full());
459    }
460
461
462    /**
463     *  Search starting from a given node of the tree.  Based on any mismatches at
464     *  higher levels, all exemplars below this level must have a distance of
465     *  at least "mindistsq".
466     */
467    void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
468                     float epsError, Heap<BranchSt>* heap, DynamicBitset& checked)
469    {
470        if (result_set.worstDist()<mindist) {
471            //			printf("Ignoring branch, too far\n");
472            return;
473        }
474
475        /* If this is a leaf node, then do check and return. */
476        if ((node->child1 == NULL)&&(node->child2 == NULL)) {
477            /*  Do not check same node more than once when searching multiple trees.
478                Once a vector is checked, we set its location in vind to the
479                current checkID.
480             */
481            int index = node->divfeat;
482            if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return;
483            checked.set(index);
484            checkCount++;
485
486            DistanceType dist = distance_(dataset_[index], vec, veclen_);
487            result_set.addPoint(dist,index);
488
489            return;
490        }
491
492        /* Which child branch should be taken first? */
493        ElementType val = vec[node->divfeat];
494        DistanceType diff = val - node->divval;
495        NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
496        NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
497
498        /* Create a branch record for the branch not taken.  Add distance
499            of this feature boundary (we don't attempt to correct for any
500            use of this feature in a parent node, which is unlikely to
501            happen and would have only a small effect).  Don't bother
502            adding more branches to heap after halfway point, as cost of
503            adding exceeds their value.
504         */
505
506        DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
507        //		if (2 * checkCount < maxCheck  ||  !result.full()) {
508        if ((new_distsq*epsError < result_set.worstDist())||  !result_set.full()) {
509            heap->insert( BranchSt(otherChild, new_distsq) );
510        }
511
512        /* Call recursively to search next level down. */
513        searchLevel(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
514    }
515
516    /**
517     * Performs an exact search in the tree starting from a node.
518     */
519    void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError)
520    {
521        /* If this is a leaf node, then do check and return. */
522        if ((node->child1 == NULL)&&(node->child2 == NULL)) {
523            int index = node->divfeat;
524            DistanceType dist = distance_(dataset_[index], vec, veclen_);
525            result_set.addPoint(dist,index);
526            return;
527        }
528
529        /* Which child branch should be taken first? */
530        ElementType val = vec[node->divfeat];
531        DistanceType diff = val - node->divval;
532        NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
533        NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
534
535        /* Create a branch record for the branch not taken.  Add distance
536            of this feature boundary (we don't attempt to correct for any
537            use of this feature in a parent node, which is unlikely to
538            happen and would have only a small effect).  Don't bother
539            adding more branches to heap after halfway point, as cost of
540            adding exceeds their value.
541         */
542
543        DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
544
545        /* Call recursively to search next level down. */
546        searchLevelExact(result_set, vec, bestChild, mindist, epsError);
547
548        if (new_distsq*epsError<=result_set.worstDist()) {
549            searchLevelExact(result_set, vec, otherChild, new_distsq, epsError);
550        }
551    }
552
553
554private:
555
    enum
    {
        /**
         * To improve efficiency, only a limited sample of the vectors is
         * used to compute the mean and variance at each level when building
         * a tree: meanSplit() reads at most SAMPLE_MEAN+1 vectors from the
         * front of the (shuffled) index list.
         * A value of 100 seems to perform as well as using all values.
         */
        SAMPLE_MEAN = 100,
        /**
         * Top random dimensions to consider
         *
         * When creating random trees, the dimension on which to subdivide is
         * selected at random from among the top RAND_DIM dimensions with the
         * highest variance.  A value of 5 works well.
         */
        RAND_DIM=5
    };
573
574
575    /**
576     * Number of randomized trees that are used
577     */
578    int trees_;
579
580    /**
581     *  Array of indices to vectors in the dataset.
582     */
583    std::vector<int> vind_;
584
585    /**
586     * The dataset used by this index
587     */
588    const Matrix<ElementType> dataset_;
589
590    IndexParams index_params_;
591
592    size_t size_;
593    size_t veclen_;
594
595
596    DistanceType* mean_;
597    DistanceType* var_;
598
599
600    /**
601     * Array of k-d trees used to find neighbours.
602     */
603    NodePtr* tree_roots_;
604
605    /**
606     * Pooled memory allocator.
607     *
     * Using a pooled memory allocator is more efficient
     * than allocating memory directly when there is a large
     * number of small memory allocations.
611     */
612    PooledAllocator pool_;
613
614    Distance distance_;
615
616
};   // class KDTreeIndex
618
619}
620
621#endif //OPENCV_FLANN_KDTREE_INDEX_H_
622