// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
28// 29// Author: sameeragarwal@google.com (Sameer Agarwal) 30 31#include "ceres/implicit_schur_complement.h" 32 33#include "Eigen/Dense" 34#include "ceres/block_sparse_matrix.h" 35#include "ceres/block_structure.h" 36#include "ceres/internal/eigen.h" 37#include "ceres/internal/scoped_ptr.h" 38#include "ceres/types.h" 39#include "glog/logging.h" 40 41namespace ceres { 42namespace internal { 43 44ImplicitSchurComplement::ImplicitSchurComplement(int num_eliminate_blocks, 45 bool preconditioner) 46 : num_eliminate_blocks_(num_eliminate_blocks), 47 preconditioner_(preconditioner), 48 A_(NULL), 49 D_(NULL), 50 b_(NULL), 51 block_diagonal_EtE_inverse_(NULL), 52 block_diagonal_FtF_inverse_(NULL) { 53} 54 55ImplicitSchurComplement::~ImplicitSchurComplement() { 56} 57 58void ImplicitSchurComplement::Init(const BlockSparseMatrix& A, 59 const double* D, 60 const double* b) { 61 // Since initialization is reasonably heavy, perhaps we can save on 62 // constructing a new object everytime. 63 if (A_ == NULL) { 64 A_.reset(new PartitionedMatrixView(A, num_eliminate_blocks_)); 65 } 66 67 D_ = D; 68 b_ = b; 69 70 // Initialize temporary storage and compute the block diagonals of 71 // E'E and F'E. 72 if (block_diagonal_EtE_inverse_ == NULL) { 73 block_diagonal_EtE_inverse_.reset(A_->CreateBlockDiagonalEtE()); 74 if (preconditioner_) { 75 block_diagonal_FtF_inverse_.reset(A_->CreateBlockDiagonalFtF()); 76 } 77 rhs_.resize(A_->num_cols_f()); 78 rhs_.setZero(); 79 tmp_rows_.resize(A_->num_rows()); 80 tmp_e_cols_.resize(A_->num_cols_e()); 81 tmp_e_cols_2_.resize(A_->num_cols_e()); 82 tmp_f_cols_.resize(A_->num_cols_f()); 83 } else { 84 A_->UpdateBlockDiagonalEtE(block_diagonal_EtE_inverse_.get()); 85 if (preconditioner_) { 86 A_->UpdateBlockDiagonalFtF(block_diagonal_FtF_inverse_.get()); 87 } 88 } 89 90 // The block diagonals of the augmented linear system contain 91 // contributions from the diagonal D if it is non-null. Add that to 92 // the block diagonals and invert them. 
93 AddDiagonalAndInvert(D_, block_diagonal_EtE_inverse_.get()); 94 if (preconditioner_) { 95 AddDiagonalAndInvert((D_ == NULL) ? NULL : D_ + A_->num_cols_e(), 96 block_diagonal_FtF_inverse_.get()); 97 } 98 99 // Compute the RHS of the Schur complement system. 100 UpdateRhs(); 101} 102 103// Evaluate the product 104// 105// Sx = [F'F - F'E (E'E)^-1 E'F]x 106// 107// By breaking it down into individual matrix vector products 108// involving the matrices E and F. This is implemented using a 109// PartitionedMatrixView of the input matrix A. 110void ImplicitSchurComplement::RightMultiply(const double* x, double* y) const { 111 // y1 = F x 112 tmp_rows_.setZero(); 113 A_->RightMultiplyF(x, tmp_rows_.data()); 114 115 // y2 = E' y1 116 tmp_e_cols_.setZero(); 117 A_->LeftMultiplyE(tmp_rows_.data(), tmp_e_cols_.data()); 118 119 // y3 = -(E'E)^-1 y2 120 tmp_e_cols_2_.setZero(); 121 block_diagonal_EtE_inverse_->RightMultiply(tmp_e_cols_.data(), 122 tmp_e_cols_2_.data()); 123 tmp_e_cols_2_ *= -1.0; 124 125 // y1 = y1 + E y3 126 A_->RightMultiplyE(tmp_e_cols_2_.data(), tmp_rows_.data()); 127 128 // y5 = D * x 129 if (D_ != NULL) { 130 ConstVectorRef Dref(D_ + A_->num_cols_e(), num_cols()); 131 VectorRef(y, num_cols()) = 132 (Dref.array().square() * 133 ConstVectorRef(x, num_cols()).array()).matrix(); 134 } else { 135 VectorRef(y, num_cols()).setZero(); 136 } 137 138 // y = y5 + F' y1 139 A_->LeftMultiplyF(tmp_rows_.data(), y); 140} 141 142// Given a block diagonal matrix and an optional array of diagonal 143// entries D, add them to the diagonal of the matrix and compute the 144// inverse of each diagonal block. 
145void ImplicitSchurComplement::AddDiagonalAndInvert( 146 const double* D, 147 BlockSparseMatrix* block_diagonal) { 148 const CompressedRowBlockStructure* block_diagonal_structure = 149 block_diagonal->block_structure(); 150 for (int r = 0; r < block_diagonal_structure->rows.size(); ++r) { 151 const int row_block_pos = block_diagonal_structure->rows[r].block.position; 152 const int row_block_size = block_diagonal_structure->rows[r].block.size; 153 const Cell& cell = block_diagonal_structure->rows[r].cells[0]; 154 MatrixRef m(block_diagonal->mutable_values() + cell.position, 155 row_block_size, row_block_size); 156 157 if (D != NULL) { 158 ConstVectorRef d(D + row_block_pos, row_block_size); 159 m += d.array().square().matrix().asDiagonal(); 160 } 161 162 m = m 163 .selfadjointView<Eigen::Upper>() 164 .llt() 165 .solve(Matrix::Identity(row_block_size, row_block_size)); 166 } 167} 168 169// Similar to RightMultiply, use the block structure of the matrix A 170// to compute y = (E'E)^-1 (E'b - E'F x). 171void ImplicitSchurComplement::BackSubstitute(const double* x, double* y) { 172 const int num_cols_e = A_->num_cols_e(); 173 const int num_cols_f = A_->num_cols_f(); 174 const int num_cols = A_->num_cols(); 175 const int num_rows = A_->num_rows(); 176 177 // y1 = F x 178 tmp_rows_.setZero(); 179 A_->RightMultiplyF(x, tmp_rows_.data()); 180 181 // y2 = b - y1 182 tmp_rows_ = ConstVectorRef(b_, num_rows) - tmp_rows_; 183 184 // y3 = E' y2 185 tmp_e_cols_.setZero(); 186 A_->LeftMultiplyE(tmp_rows_.data(), tmp_e_cols_.data()); 187 188 // y = (E'E)^-1 y3 189 VectorRef(y, num_cols).setZero(); 190 block_diagonal_EtE_inverse_->RightMultiply(tmp_e_cols_.data(), y); 191 192 // The full solution vector y has two blocks. The first block of 193 // variables corresponds to the eliminated variables, which we just 194 // computed via back substitution. 
The second block of variables 195 // corresponds to the Schur complement system, so we just copy those 196 // values from the solution to the Schur complement. 197 VectorRef(y + num_cols_e, num_cols_f) = ConstVectorRef(x, num_cols_f); 198} 199 200// Compute the RHS of the Schur complement system. 201// 202// rhs = F'b - F'E (E'E)^-1 E'b 203// 204// Like BackSubstitute, we use the block structure of A to implement 205// this using a series of matrix vector products. 206void ImplicitSchurComplement::UpdateRhs() { 207 // y1 = E'b 208 tmp_e_cols_.setZero(); 209 A_->LeftMultiplyE(b_, tmp_e_cols_.data()); 210 211 // y2 = (E'E)^-1 y1 212 Vector y2 = Vector::Zero(A_->num_cols_e()); 213 block_diagonal_EtE_inverse_->RightMultiply(tmp_e_cols_.data(), y2.data()); 214 215 // y3 = E y2 216 tmp_rows_.setZero(); 217 A_->RightMultiplyE(y2.data(), tmp_rows_.data()); 218 219 // y3 = b - y3 220 tmp_rows_ = ConstVectorRef(b_, A_->num_rows()) - tmp_rows_; 221 222 // rhs = F' y3 223 rhs_.setZero(); 224 A_->LeftMultiplyF(tmp_rows_.data(), rhs_.data()); 225} 226 227} // namespace internal 228} // namespace ceres 229