canonical_views_clustering.cc revision 0ae28bd5885b5daa526898fcf7c323dc2c3e1963
1// Ceres Solver - A fast non-linear least squares minimizer 2// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. 3// http://code.google.com/p/ceres-solver/ 4// 5// Redistribution and use in source and binary forms, with or without 6// modification, are permitted provided that the following conditions are met: 7// 8// * Redistributions of source code must retain the above copyright notice, 9// this list of conditions and the following disclaimer. 10// * Redistributions in binary form must reproduce the above copyright notice, 11// this list of conditions and the following disclaimer in the documentation 12// and/or other materials provided with the distribution. 13// * Neither the name of Google Inc. nor the names of its contributors may be 14// used to endorse or promote products derived from this software without 15// specific prior written permission. 16// 17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27// POSSIBILITY OF SUCH DAMAGE. 28// 29// Author: David Gallup (dgallup@google.com) 30// Sameer Agarwal (sameeragarwal@google.com) 31 32#include "ceres/canonical_views_clustering.h" 33 34#include "ceres/collections_port.h" 35#include "ceres/graph.h" 36#include "ceres/internal/macros.h" 37#include "ceres/map_util.h" 38#include "glog/logging.h" 39 40namespace ceres { 41namespace internal { 42 43typedef HashMap<int, int> IntMap; 44typedef HashSet<int> IntSet; 45 46class CanonicalViewsClustering { 47 public: 48 CanonicalViewsClustering() {} 49 50 // Compute the canonical views clustering of the vertices of the 51 // graph. centers will contain the vertices that are the identified 52 // as the canonical views/cluster centers, and membership is a map 53 // from vertices to cluster_ids. The i^th cluster center corresponds 54 // to the i^th cluster. It is possible depending on the 55 // configuration of the clustering algorithm that some of the 56 // vertices may not be assigned to any cluster. In this case they 57 // are assigned to a cluster with id = kInvalidClusterId. 58 void ComputeClustering(const Graph<int>& graph, 59 const CanonicalViewsClusteringOptions& options, 60 vector<int>* centers, 61 IntMap* membership); 62 63 private: 64 void FindValidViews(IntSet* valid_views) const; 65 double ComputeClusteringQualityDifference(const int candidate, 66 const vector<int>& centers) const; 67 void UpdateCanonicalViewAssignments(const int canonical_view); 68 void ComputeClusterMembership(const vector<int>& centers, 69 IntMap* membership) const; 70 71 CanonicalViewsClusteringOptions options_; 72 const Graph<int>* graph_; 73 // Maps a view to its representative canonical view (its cluster 74 // center). 75 IntMap view_to_canonical_view_; 76 // Maps a view to its similarity to its current cluster center. 77 HashMap<int, double> view_to_canonical_view_similarity_; 78 CERES_DISALLOW_COPY_AND_ASSIGN(CanonicalViewsClustering); 79}; 80 81void ComputeCanonicalViewsClustering( 82 const Graph<int>& graph, 83 const CanonicalViewsClusteringOptions& options, 84 vector<int>* centers, 85 IntMap* membership) { 86 time_t start_time = time(NULL); 87 CanonicalViewsClustering cv; 88 cv.ComputeClustering(graph, options, centers, membership); 89 VLOG(2) << "Canonical views clustering time (secs): " 90 << time(NULL) - start_time; 91} 92 93// Implementation of CanonicalViewsClustering 94void CanonicalViewsClustering::ComputeClustering( 95 const Graph<int>& graph, 96 const CanonicalViewsClusteringOptions& options, 97 vector<int>* centers, 98 IntMap* membership) { 99 options_ = options; 100 CHECK_NOTNULL(centers)->clear(); 101 CHECK_NOTNULL(membership)->clear(); 102 graph_ = &graph; 103 104 IntSet valid_views; 105 FindValidViews(&valid_views); 106 while (valid_views.size() > 0) { 107 // Find the next best canonical view. 108 double best_difference = -std::numeric_limits<double>::max(); 109 int best_view = 0; 110 111 // TODO(sameeragarwal): Make this loop multi-threaded. 112 for (IntSet::const_iterator view = valid_views.begin(); 113 view != valid_views.end(); 114 ++view) { 115 const double difference = 116 ComputeClusteringQualityDifference(*view, *centers); 117 if (difference > best_difference) { 118 best_difference = difference; 119 best_view = *view; 120 } 121 } 122 123 CHECK_GT(best_difference, -std::numeric_limits<double>::max()); 124 125 // Add canonical view if quality improves, or if minimum is not 126 // yet met, otherwise break. 127 if ((best_difference <= 0) && 128 (centers->size() >= options_.min_views)) { 129 break; 130 } 131 132 centers->push_back(best_view); 133 valid_views.erase(best_view); 134 UpdateCanonicalViewAssignments(best_view); 135 } 136 137 ComputeClusterMembership(*centers, membership); 138} 139 140// Return the set of vertices of the graph which have valid vertex 141// weights. 142void CanonicalViewsClustering::FindValidViews( 143 IntSet* valid_views) const { 144 const IntSet& views = graph_->vertices(); 145 for (IntSet::const_iterator view = views.begin(); 146 view != views.end(); 147 ++view) { 148 if (graph_->VertexWeight(*view) != Graph<int>::InvalidWeight()) { 149 valid_views->insert(*view); 150 } 151 } 152} 153 154// Computes the difference in the quality score if 'candidate' were 155// added to the set of canonical views. 156double CanonicalViewsClustering::ComputeClusteringQualityDifference( 157 const int candidate, 158 const vector<int>& centers) const { 159 // View score. 160 double difference = 161 options_.view_score_weight * graph_->VertexWeight(candidate); 162 163 // Compute how much the quality score changes if the candidate view 164 // was added to the list of canonical views and its nearest 165 // neighbors became members of its cluster. 166 const IntSet& neighbors = graph_->Neighbors(candidate); 167 for (IntSet::const_iterator neighbor = neighbors.begin(); 168 neighbor != neighbors.end(); 169 ++neighbor) { 170 const double old_similarity = 171 FindWithDefault(view_to_canonical_view_similarity_, *neighbor, 0.0); 172 const double new_similarity = graph_->EdgeWeight(*neighbor, candidate); 173 if (new_similarity > old_similarity) { 174 difference += new_similarity - old_similarity; 175 } 176 } 177 178 // Number of views penalty. 179 difference -= options_.size_penalty_weight; 180 181 // Orthogonality. 182 for (int i = 0; i < centers.size(); ++i) { 183 difference -= options_.similarity_penalty_weight * 184 graph_->EdgeWeight(centers[i], candidate); 185 } 186 187 return difference; 188} 189 190// Reassign views if they're more similar to the new canonical view. 191void CanonicalViewsClustering::UpdateCanonicalViewAssignments( 192 const int canonical_view) { 193 const IntSet& neighbors = graph_->Neighbors(canonical_view); 194 for (IntSet::const_iterator neighbor = neighbors.begin(); 195 neighbor != neighbors.end(); 196 ++neighbor) { 197 const double old_similarity = 198 FindWithDefault(view_to_canonical_view_similarity_, *neighbor, 0.0); 199 const double new_similarity = 200 graph_->EdgeWeight(*neighbor, canonical_view); 201 if (new_similarity > old_similarity) { 202 view_to_canonical_view_[*neighbor] = canonical_view; 203 view_to_canonical_view_similarity_[*neighbor] = new_similarity; 204 } 205 } 206} 207 208// Assign a cluster id to each view. 209void CanonicalViewsClustering::ComputeClusterMembership( 210 const vector<int>& centers, 211 IntMap* membership) const { 212 CHECK_NOTNULL(membership)->clear(); 213 214 // The i^th cluster has cluster id i. 215 IntMap center_to_cluster_id; 216 for (int i = 0; i < centers.size(); ++i) { 217 center_to_cluster_id[centers[i]] = i; 218 } 219 220 static const int kInvalidClusterId = -1; 221 222 const IntSet& views = graph_->vertices(); 223 for (IntSet::const_iterator view = views.begin(); 224 view != views.end(); 225 ++view) { 226 IntMap::const_iterator it = 227 view_to_canonical_view_.find(*view); 228 int cluster_id = kInvalidClusterId; 229 if (it != view_to_canonical_view_.end()) { 230 cluster_id = FindOrDie(center_to_cluster_id, it->second); 231 } 232 233 InsertOrDie(membership, *view, cluster_id); 234 } 235} 236 237} // namespace internal 238} // namespace ceres 239