// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
28// 29// Author: sameeragarwal@google.com (Sameer Agarwal) 30 31#include "ceres/schur_eliminator.h" 32 33#include "Eigen/Dense" 34#include "ceres/block_random_access_dense_matrix.h" 35#include "ceres/block_sparse_matrix.h" 36#include "ceres/casts.h" 37#include "ceres/detect_structure.h" 38#include "ceres/internal/eigen.h" 39#include "ceres/internal/scoped_ptr.h" 40#include "ceres/linear_least_squares_problems.h" 41#include "ceres/test_util.h" 42#include "ceres/triplet_sparse_matrix.h" 43#include "ceres/types.h" 44#include "glog/logging.h" 45#include "gtest/gtest.h" 46 47// TODO(sameeragarwal): Reduce the size of these tests and redo the 48// parameterization to be more efficient. 49 50namespace ceres { 51namespace internal { 52 53class SchurEliminatorTest : public ::testing::Test { 54 protected: 55 void SetUpFromId(int id) { 56 scoped_ptr<LinearLeastSquaresProblem> 57 problem(CreateLinearLeastSquaresProblemFromId(id)); 58 CHECK_NOTNULL(problem.get()); 59 SetupHelper(problem.get()); 60 } 61 62 void SetUpFromFilename(const string& filename) { 63 scoped_ptr<LinearLeastSquaresProblem> 64 problem(CreateLinearLeastSquaresProblemFromFile(filename)); 65 CHECK_NOTNULL(problem.get()); 66 SetupHelper(problem.get()); 67 } 68 69 void SetupHelper(LinearLeastSquaresProblem* problem) { 70 A.reset(down_cast<BlockSparseMatrix*>(problem->A.release())); 71 b.reset(problem->b.release()); 72 D.reset(problem->D.release()); 73 74 num_eliminate_blocks = problem->num_eliminate_blocks; 75 num_eliminate_cols = 0; 76 const CompressedRowBlockStructure* bs = A->block_structure(); 77 78 for (int i = 0; i < num_eliminate_blocks; ++i) { 79 num_eliminate_cols += bs->cols[i].size; 80 } 81 } 82 83 // Compute the golden values for the reduced linear system and the 84 // solution to the linear least squares problem using dense linear 85 // algebra. 
86 void ComputeReferenceSolution(const Vector& D) { 87 Matrix J; 88 A->ToDenseMatrix(&J); 89 VectorRef f(b.get(), J.rows()); 90 91 Matrix H = (D.cwiseProduct(D)).asDiagonal(); 92 H.noalias() += J.transpose() * J; 93 94 const Vector g = J.transpose() * f; 95 const int schur_size = J.cols() - num_eliminate_cols; 96 97 lhs_expected.resize(schur_size, schur_size); 98 lhs_expected.setZero(); 99 100 rhs_expected.resize(schur_size); 101 rhs_expected.setZero(); 102 103 sol_expected.resize(J.cols()); 104 sol_expected.setZero(); 105 106 Matrix P = H.block(0, 0, num_eliminate_cols, num_eliminate_cols); 107 Matrix Q = H.block(0, 108 num_eliminate_cols, 109 num_eliminate_cols, 110 schur_size); 111 Matrix R = H.block(num_eliminate_cols, 112 num_eliminate_cols, 113 schur_size, 114 schur_size); 115 int row = 0; 116 const CompressedRowBlockStructure* bs = A->block_structure(); 117 for (int i = 0; i < num_eliminate_blocks; ++i) { 118 const int block_size = bs->cols[i].size; 119 P.block(row, row, block_size, block_size) = 120 P 121 .block(row, row, block_size, block_size) 122 .ldlt() 123 .solve(Matrix::Identity(block_size, block_size)); 124 row += block_size; 125 } 126 127 lhs_expected 128 .triangularView<Eigen::Upper>() = R - Q.transpose() * P * Q; 129 rhs_expected = 130 g.tail(schur_size) - Q.transpose() * P * g.head(num_eliminate_cols); 131 sol_expected = H.ldlt().solve(g); 132 } 133 134 void EliminateSolveAndCompare(const VectorRef& diagonal, 135 bool use_static_structure, 136 const double relative_tolerance) { 137 const CompressedRowBlockStructure* bs = A->block_structure(); 138 const int num_col_blocks = bs->cols.size(); 139 vector<int> blocks(num_col_blocks - num_eliminate_blocks, 0); 140 for (int i = num_eliminate_blocks; i < num_col_blocks; ++i) { 141 blocks[i - num_eliminate_blocks] = bs->cols[i].size; 142 } 143 144 BlockRandomAccessDenseMatrix lhs(blocks); 145 146 const int num_cols = A->num_cols(); 147 const int schur_size = lhs.num_rows(); 148 149 Vector rhs(schur_size); 
150 151 LinearSolver::Options options; 152 options.elimination_groups.push_back(num_eliminate_blocks); 153 if (use_static_structure) { 154 DetectStructure(*bs, 155 num_eliminate_blocks, 156 &options.row_block_size, 157 &options.e_block_size, 158 &options.f_block_size); 159 } 160 161 scoped_ptr<SchurEliminatorBase> eliminator; 162 eliminator.reset(SchurEliminatorBase::Create(options)); 163 eliminator->Init(num_eliminate_blocks, A->block_structure()); 164 eliminator->Eliminate(A.get(), b.get(), diagonal.data(), &lhs, rhs.data()); 165 166 MatrixRef lhs_ref(lhs.mutable_values(), lhs.num_rows(), lhs.num_cols()); 167 Vector reduced_sol = 168 lhs_ref 169 .selfadjointView<Eigen::Upper>() 170 .ldlt() 171 .solve(rhs); 172 173 // Solution to the linear least squares problem. 174 Vector sol(num_cols); 175 sol.setZero(); 176 sol.tail(schur_size) = reduced_sol; 177 eliminator->BackSubstitute(A.get(), 178 b.get(), 179 diagonal.data(), 180 reduced_sol.data(), 181 sol.data()); 182 183 Matrix delta = (lhs_ref - lhs_expected).selfadjointView<Eigen::Upper>(); 184 double diff = delta.norm(); 185 EXPECT_NEAR(diff / lhs_expected.norm(), 0.0, relative_tolerance); 186 EXPECT_NEAR((rhs - rhs_expected).norm() / rhs_expected.norm(), 0.0, 187 relative_tolerance); 188 EXPECT_NEAR((sol - sol_expected).norm() / sol_expected.norm(), 0.0, 189 relative_tolerance); 190 } 191 192 scoped_ptr<BlockSparseMatrix> A; 193 scoped_array<double> b; 194 scoped_array<double> D; 195 int num_eliminate_blocks; 196 int num_eliminate_cols; 197 198 Matrix lhs_expected; 199 Vector rhs_expected; 200 Vector sol_expected; 201}; 202 203TEST_F(SchurEliminatorTest, ScalarProblem) { 204 SetUpFromId(2); 205 Vector zero(A->num_cols()); 206 zero.setZero(); 207 208 ComputeReferenceSolution(VectorRef(zero.data(), A->num_cols())); 209 EliminateSolveAndCompare(VectorRef(zero.data(), A->num_cols()), true, 1e-14); 210 EliminateSolveAndCompare(VectorRef(zero.data(), A->num_cols()), false, 1e-14); 211 212 
ComputeReferenceSolution(VectorRef(D.get(), A->num_cols())); 213 EliminateSolveAndCompare(VectorRef(D.get(), A->num_cols()), true, 1e-14); 214 EliminateSolveAndCompare(VectorRef(D.get(), A->num_cols()), false, 1e-14); 215} 216 217#ifndef CERES_NO_PROTOCOL_BUFFERS 218TEST_F(SchurEliminatorTest, BlockProblem) { 219 const string input_file = TestFileAbsolutePath("problem-6-1384-000.lsqp"); 220 221 SetUpFromFilename(input_file); 222 ComputeReferenceSolution(VectorRef(D.get(), A->num_cols())); 223 EliminateSolveAndCompare(VectorRef(D.get(), A->num_cols()), true, 1e-10); 224 EliminateSolveAndCompare(VectorRef(D.get(), A->num_cols()), false, 1e-10); 225} 226#endif // CERES_NO_PROTOCOL_BUFFERS 227 228} // namespace internal 229} // namespace ceres 230