// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
28// 29// Author: sameeragarwal@google.com (Sameer Agarwal) 30 31#include "ceres/internal/eigen.h" 32#include "ceres/internal/scoped_ptr.h" 33#include "ceres/levenberg_marquardt_strategy.h" 34#include "ceres/linear_solver.h" 35#include "ceres/trust_region_strategy.h" 36#include "glog/logging.h" 37#include "gmock/gmock.h" 38#include "gmock/mock-log.h" 39#include "gtest/gtest.h" 40 41using testing::AllOf; 42using testing::AnyNumber; 43using testing::HasSubstr; 44using testing::ScopedMockLog; 45using testing::_; 46 47namespace ceres { 48namespace internal { 49 50const double kTolerance = 1e-16; 51 52// Linear solver that takes as input a vector and checks that the 53// caller passes the same vector as LinearSolver::PerSolveOptions.D. 54class RegularizationCheckingLinearSolver : public DenseSparseMatrixSolver { 55 public: 56 RegularizationCheckingLinearSolver(const int num_cols, const double* diagonal) 57 : num_cols_(num_cols), 58 diagonal_(diagonal) { 59 } 60 61 virtual ~RegularizationCheckingLinearSolver() {} 62 63 private: 64 virtual LinearSolver::Summary SolveImpl( 65 DenseSparseMatrix* A, 66 const double* b, 67 const LinearSolver::PerSolveOptions& per_solve_options, 68 double* x) { 69 CHECK_NOTNULL(per_solve_options.D); 70 for (int i = 0; i < num_cols_; ++i) { 71 EXPECT_NEAR(per_solve_options.D[i], diagonal_[i], kTolerance) 72 << i << " " << per_solve_options.D[i] << " " << diagonal_[i]; 73 } 74 return LinearSolver::Summary(); 75 } 76 77 const int num_cols_; 78 const double* diagonal_; 79}; 80 81TEST(LevenbergMarquardtStrategy, AcceptRejectStepRadiusScaling) { 82 TrustRegionStrategy::Options options; 83 options.initial_radius = 2.0; 84 options.max_radius = 20.0; 85 options.min_lm_diagonal = 1e-8; 86 options.max_lm_diagonal = 1e8; 87 88 // We need a non-null pointer here, so anything should do. 
89 scoped_ptr<LinearSolver> linear_solver( 90 new RegularizationCheckingLinearSolver(0, NULL)); 91 options.linear_solver = linear_solver.get(); 92 93 LevenbergMarquardtStrategy lms(options); 94 EXPECT_EQ(lms.Radius(), options.initial_radius); 95 lms.StepRejected(0.0); 96 EXPECT_EQ(lms.Radius(), 1.0); 97 lms.StepRejected(-1.0); 98 EXPECT_EQ(lms.Radius(), 0.25); 99 lms.StepAccepted(1.0); 100 EXPECT_EQ(lms.Radius(), 0.25 * 3.0); 101 lms.StepAccepted(1.0); 102 EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0); 103 lms.StepAccepted(0.25); 104 EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0 / 1.125); 105 lms.StepAccepted(1.0); 106 EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0 / 1.125 * 3.0); 107 lms.StepAccepted(1.0); 108 EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0 / 1.125 * 3.0 * 3.0); 109 lms.StepAccepted(1.0); 110 EXPECT_EQ(lms.Radius(), options.max_radius); 111} 112 113TEST(LevenbergMarquardtStrategy, CorrectDiagonalToLinearSolver) { 114 Matrix jacobian(2, 3); 115 jacobian.setZero(); 116 jacobian(0, 0) = 0.0; 117 jacobian(0, 1) = 1.0; 118 jacobian(1, 1) = 1.0; 119 jacobian(0, 2) = 100.0; 120 121 double residual = 1.0; 122 double x[3]; 123 DenseSparseMatrix dsm(jacobian); 124 125 TrustRegionStrategy::Options options; 126 options.initial_radius = 2.0; 127 options.max_radius = 20.0; 128 options.min_lm_diagonal = 1e-2; 129 options.max_lm_diagonal = 1e2; 130 131 double diagonal[3]; 132 diagonal[0] = options.min_lm_diagonal; 133 diagonal[1] = 2.0; 134 diagonal[2] = options.max_lm_diagonal; 135 for (int i = 0; i < 3; ++i) { 136 diagonal[i] = sqrt(diagonal[i] / options.initial_radius); 137 } 138 139 RegularizationCheckingLinearSolver linear_solver(3, diagonal); 140 options.linear_solver = &linear_solver; 141 142 LevenbergMarquardtStrategy lms(options); 143 TrustRegionStrategy::PerSolveOptions pso; 144 145 { 146 ScopedMockLog log; 147 EXPECT_CALL(log, Log(_, _, _)).Times(AnyNumber()); 148 EXPECT_CALL(log, Log(WARNING, _, 149 HasSubstr("Failed to compute a finite step."))); 150 151 
TrustRegionStrategy::Summary summary = 152 lms.ComputeStep(pso, &dsm, &residual, x); 153 EXPECT_EQ(summary.termination_type, LINEAR_SOLVER_FAILURE); 154 } 155} 156 157} // namespace internal 158} // namespace ceres 159