1// Ceres Solver - A fast non-linear least squares minimizer 2// Copyright 2010, 2011, 2012 Google Inc. All rights reserved. 3// http://code.google.com/p/ceres-solver/ 4// 5// Redistribution and use in source and binary forms, with or without 6// modification, are permitted provided that the following conditions are met: 7// 8// * Redistributions of source code must retain the above copyright notice, 9// this list of conditions and the following disclaimer. 10// * Redistributions in binary form must reproduce the above copyright notice, 11// this list of conditions and the following disclaimer in the documentation 12// and/or other materials provided with the distribution. 13// * Neither the name of Google Inc. nor the names of its contributors may be 14// used to endorse or promote products derived from this software without 15// specific prior written permission. 16// 17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27// POSSIBILITY OF SUCH DAMAGE. 28// 29// Author: sameeragarwal@google.com (Sameer Agarwal) 30 31#include "ceres/linear_least_squares_problems.h" 32 33#include <cstdio> 34#include <string> 35#include <vector> 36#include "ceres/block_sparse_matrix.h" 37#include "ceres/block_structure.h" 38#include "ceres/casts.h" 39#include "ceres/file.h" 40#include "ceres/internal/scoped_ptr.h" 41#include "ceres/stringprintf.h" 42#include "ceres/triplet_sparse_matrix.h" 43#include "ceres/types.h" 44#include "glog/logging.h" 45 46namespace ceres { 47namespace internal { 48 49LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromId(int id) { 50 switch (id) { 51 case 0: 52 return LinearLeastSquaresProblem0(); 53 case 1: 54 return LinearLeastSquaresProblem1(); 55 case 2: 56 return LinearLeastSquaresProblem2(); 57 case 3: 58 return LinearLeastSquaresProblem3(); 59 default: 60 LOG(FATAL) << "Unknown problem id requested " << id; 61 } 62 return NULL; 63} 64 65/* 66A = [1 2] 67 [3 4] 68 [6 -10] 69 70b = [ 8 71 18 72 -18] 73 74x = [2 75 3] 76 77D = [1 78 2] 79 80x_D = [1.78448275; 81 2.82327586;] 82 */ 83LinearLeastSquaresProblem* LinearLeastSquaresProblem0() { 84 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem; 85 86 TripletSparseMatrix* A = new TripletSparseMatrix(3, 2, 6); 87 problem->b.reset(new double[3]); 88 problem->D.reset(new double[2]); 89 90 problem->x.reset(new double[2]); 91 problem->x_D.reset(new double[2]); 92 93 int* Ai = A->mutable_rows(); 94 int* Aj = A->mutable_cols(); 95 double* Ax = A->mutable_values(); 96 97 int counter = 0; 98 for (int i = 0; i < 3; ++i) { 99 for (int j = 0; j< 2; ++j) { 100 Ai[counter]=i; 101 Aj[counter]=j; 102 ++counter; 103 } 104 }; 105 106 Ax[0] = 1.; 107 Ax[1] = 2.; 108 Ax[2] = 3.; 109 Ax[3] = 4.; 110 Ax[4] = 6; 111 Ax[5] = -10; 112 A->set_num_nonzeros(6); 113 problem->A.reset(A); 114 115 problem->b[0] = 8; 116 problem->b[1] = 18; 117 problem->b[2] = -18; 118 119 problem->x[0] = 2.0; 120 problem->x[1] = 3.0; 121 122 problem->D[0] = 1; 123 problem->D[1] = 2; 124 125 problem->x_D[0] = 1.78448275; 126 problem->x_D[1] = 2.82327586; 127 return problem; 128} 129 130 131/* 132 A = [1 0 | 2 0 0 133 3 0 | 0 4 0 134 0 5 | 0 0 6 135 0 7 | 8 0 0 136 0 9 | 1 0 0 137 0 0 | 1 1 1] 138 139 b = [0 140 1 141 2 142 3 143 4 144 5] 145 146 c = A'* b = [ 3 147 67 148 33 149 9 150 17] 151 152 A'A = [10 0 2 12 0 153 0 155 65 0 30 154 2 65 70 1 1 155 12 0 1 17 1 156 0 30 1 1 37] 157 158 S = [ 42.3419 -1.4000 -11.5806 159 -1.4000 2.6000 1.0000 160 11.5806 1.0000 31.1935] 161 162 r = [ 4.3032 163 5.4000 164 5.0323] 165 166 S\r = [ 0.2102 167 2.1367 168 0.1388] 169 170 A\b = [-2.3061 171 0.3172 172 0.2102 173 2.1367 174 0.1388] 175*/ 176// The following two functions create a TripletSparseMatrix and a 177// BlockSparseMatrix version of this problem. 178 179// TripletSparseMatrix version. 180LinearLeastSquaresProblem* LinearLeastSquaresProblem1() { 181 int num_rows = 6; 182 int num_cols = 5; 183 184 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem; 185 TripletSparseMatrix* A = new TripletSparseMatrix(num_rows, 186 num_cols, 187 num_rows * num_cols); 188 problem->b.reset(new double[num_rows]); 189 problem->D.reset(new double[num_cols]); 190 problem->num_eliminate_blocks = 2; 191 192 int* rows = A->mutable_rows(); 193 int* cols = A->mutable_cols(); 194 double* values = A->mutable_values(); 195 196 int nnz = 0; 197 198 // Row 1 199 { 200 rows[nnz] = 0; 201 cols[nnz] = 0; 202 values[nnz++] = 1; 203 204 rows[nnz] = 0; 205 cols[nnz] = 2; 206 values[nnz++] = 2; 207 } 208 209 // Row 2 210 { 211 rows[nnz] = 1; 212 cols[nnz] = 0; 213 values[nnz++] = 3; 214 215 rows[nnz] = 1; 216 cols[nnz] = 3; 217 values[nnz++] = 4; 218 } 219 220 // Row 3 221 { 222 rows[nnz] = 2; 223 cols[nnz] = 1; 224 values[nnz++] = 5; 225 226 rows[nnz] = 2; 227 cols[nnz] = 4; 228 values[nnz++] = 6; 229 } 230 231 // Row 4 232 { 233 rows[nnz] = 3; 234 cols[nnz] = 1; 235 values[nnz++] = 7; 236 237 rows[nnz] = 3; 238 cols[nnz] = 2; 239 values[nnz++] = 8; 240 } 241 242 // Row 5 243 { 244 rows[nnz] = 4; 245 cols[nnz] = 1; 246 values[nnz++] = 9; 247 248 rows[nnz] = 4; 249 cols[nnz] = 2; 250 values[nnz++] = 1; 251 } 252 253 // Row 6 254 { 255 rows[nnz] = 5; 256 cols[nnz] = 2; 257 values[nnz++] = 1; 258 259 rows[nnz] = 5; 260 cols[nnz] = 3; 261 values[nnz++] = 1; 262 263 rows[nnz] = 5; 264 cols[nnz] = 4; 265 values[nnz++] = 1; 266 } 267 268 A->set_num_nonzeros(nnz); 269 CHECK(A->IsValid()); 270 271 problem->A.reset(A); 272 273 for (int i = 0; i < num_cols; ++i) { 274 problem->D.get()[i] = 1; 275 } 276 277 for (int i = 0; i < num_rows; ++i) { 278 problem->b.get()[i] = i; 279 } 280 281 return problem; 282} 283 284// BlockSparseMatrix version 285LinearLeastSquaresProblem* LinearLeastSquaresProblem2() { 286 int num_rows = 6; 287 int num_cols = 5; 288 289 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem; 290 291 problem->b.reset(new double[num_rows]); 292 problem->D.reset(new double[num_cols]); 293 problem->num_eliminate_blocks = 2; 294 295 CompressedRowBlockStructure* bs = new CompressedRowBlockStructure; 296 scoped_array<double> values(new double[num_rows * num_cols]); 297 298 for (int c = 0; c < num_cols; ++c) { 299 bs->cols.push_back(Block()); 300 bs->cols.back().size = 1; 301 bs->cols.back().position = c; 302 } 303 304 int nnz = 0; 305 306 // Row 1 307 { 308 values[nnz++] = 1; 309 values[nnz++] = 2; 310 311 bs->rows.push_back(CompressedRow()); 312 CompressedRow& row = bs->rows.back(); 313 row.block.size = 1; 314 row.block.position = 0; 315 row.cells.push_back(Cell(0, 0)); 316 row.cells.push_back(Cell(2, 1)); 317 } 318 319 // Row 2 320 { 321 values[nnz++] = 3; 322 values[nnz++] = 4; 323 324 bs->rows.push_back(CompressedRow()); 325 CompressedRow& row = bs->rows.back(); 326 row.block.size = 1; 327 row.block.position = 1; 328 row.cells.push_back(Cell(0, 2)); 329 row.cells.push_back(Cell(3, 3)); 330 } 331 332 // Row 3 333 { 334 values[nnz++] = 5; 335 values[nnz++] = 6; 336 337 bs->rows.push_back(CompressedRow()); 338 CompressedRow& row = bs->rows.back(); 339 row.block.size = 1; 340 row.block.position = 2; 341 row.cells.push_back(Cell(1, 4)); 342 row.cells.push_back(Cell(4, 5)); 343 } 344 345 // Row 4 346 { 347 values[nnz++] = 7; 348 values[nnz++] = 8; 349 350 bs->rows.push_back(CompressedRow()); 351 CompressedRow& row = bs->rows.back(); 352 row.block.size = 1; 353 row.block.position = 3; 354 row.cells.push_back(Cell(1, 6)); 355 row.cells.push_back(Cell(2, 7)); 356 } 357 358 // Row 5 359 { 360 values[nnz++] = 9; 361 values[nnz++] = 1; 362 363 bs->rows.push_back(CompressedRow()); 364 CompressedRow& row = bs->rows.back(); 365 row.block.size = 1; 366 row.block.position = 4; 367 row.cells.push_back(Cell(1, 8)); 368 row.cells.push_back(Cell(2, 9)); 369 } 370 371 // Row 6 372 { 373 values[nnz++] = 1; 374 values[nnz++] = 1; 375 values[nnz++] = 1; 376 377 bs->rows.push_back(CompressedRow()); 378 CompressedRow& row = bs->rows.back(); 379 row.block.size = 1; 380 row.block.position = 5; 381 row.cells.push_back(Cell(2, 10)); 382 row.cells.push_back(Cell(3, 11)); 383 row.cells.push_back(Cell(4, 12)); 384 } 385 386 BlockSparseMatrix* A = new BlockSparseMatrix(bs); 387 memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values())); 388 389 for (int i = 0; i < num_cols; ++i) { 390 problem->D.get()[i] = 1; 391 } 392 393 for (int i = 0; i < num_rows; ++i) { 394 problem->b.get()[i] = i; 395 } 396 397 problem->A.reset(A); 398 399 return problem; 400} 401 402 403/* 404 A = [1 0 405 3 0 406 0 5 407 0 7 408 0 9 409 0 0] 410 411 b = [0 412 1 413 2 414 3 415 4 416 5] 417*/ 418// BlockSparseMatrix version 419LinearLeastSquaresProblem* LinearLeastSquaresProblem3() { 420 int num_rows = 5; 421 int num_cols = 2; 422 423 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem; 424 425 problem->b.reset(new double[num_rows]); 426 problem->D.reset(new double[num_cols]); 427 problem->num_eliminate_blocks = 2; 428 429 CompressedRowBlockStructure* bs = new CompressedRowBlockStructure; 430 scoped_array<double> values(new double[num_rows * num_cols]); 431 432 for (int c = 0; c < num_cols; ++c) { 433 bs->cols.push_back(Block()); 434 bs->cols.back().size = 1; 435 bs->cols.back().position = c; 436 } 437 438 int nnz = 0; 439 440 // Row 1 441 { 442 values[nnz++] = 1; 443 bs->rows.push_back(CompressedRow()); 444 CompressedRow& row = bs->rows.back(); 445 row.block.size = 1; 446 row.block.position = 0; 447 row.cells.push_back(Cell(0, 0)); 448 } 449 450 // Row 2 451 { 452 values[nnz++] = 3; 453 bs->rows.push_back(CompressedRow()); 454 CompressedRow& row = bs->rows.back(); 455 row.block.size = 1; 456 row.block.position = 1; 457 row.cells.push_back(Cell(0, 1)); 458 } 459 460 // Row 3 461 { 462 values[nnz++] = 5; 463 bs->rows.push_back(CompressedRow()); 464 CompressedRow& row = bs->rows.back(); 465 row.block.size = 1; 466 row.block.position = 2; 467 row.cells.push_back(Cell(1, 2)); 468 } 469 470 // Row 4 471 { 472 values[nnz++] = 7; 473 bs->rows.push_back(CompressedRow()); 474 CompressedRow& row = bs->rows.back(); 475 row.block.size = 1; 476 row.block.position = 3; 477 row.cells.push_back(Cell(1, 3)); 478 } 479 480 // Row 5 481 { 482 values[nnz++] = 9; 483 bs->rows.push_back(CompressedRow()); 484 CompressedRow& row = bs->rows.back(); 485 row.block.size = 1; 486 row.block.position = 4; 487 row.cells.push_back(Cell(1, 4)); 488 } 489 490 BlockSparseMatrix* A = new BlockSparseMatrix(bs); 491 memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values())); 492 493 for (int i = 0; i < num_cols; ++i) { 494 problem->D.get()[i] = 1; 495 } 496 497 for (int i = 0; i < num_rows; ++i) { 498 problem->b.get()[i] = i; 499 } 500 501 problem->A.reset(A); 502 503 return problem; 504} 505 506namespace { 507bool DumpLinearLeastSquaresProblemToConsole(const SparseMatrix* A, 508 const double* D, 509 const double* b, 510 const double* x, 511 int num_eliminate_blocks) { 512 CHECK_NOTNULL(A); 513 Matrix AA; 514 A->ToDenseMatrix(&AA); 515 LOG(INFO) << "A^T: \n" << AA.transpose(); 516 517 if (D != NULL) { 518 LOG(INFO) << "A's appended diagonal:\n" 519 << ConstVectorRef(D, A->num_cols()); 520 } 521 522 if (b != NULL) { 523 LOG(INFO) << "b: \n" << ConstVectorRef(b, A->num_rows()); 524 } 525 526 if (x != NULL) { 527 LOG(INFO) << "x: \n" << ConstVectorRef(x, A->num_cols()); 528 } 529 return true; 530}; 531 532void WriteArrayToFileOrDie(const string& filename, 533 const double* x, 534 const int size) { 535 CHECK_NOTNULL(x); 536 VLOG(2) << "Writing array to: " << filename; 537 FILE* fptr = fopen(filename.c_str(), "w"); 538 CHECK_NOTNULL(fptr); 539 for (int i = 0; i < size; ++i) { 540 fprintf(fptr, "%17f\n", x[i]); 541 } 542 fclose(fptr); 543} 544 545bool DumpLinearLeastSquaresProblemToTextFile(const string& filename_base, 546 const SparseMatrix* A, 547 const double* D, 548 const double* b, 549 const double* x, 550 int num_eliminate_blocks) { 551 CHECK_NOTNULL(A); 552 LOG(INFO) << "writing to: " << filename_base << "*"; 553 554 string matlab_script; 555 StringAppendF(&matlab_script, 556 "function lsqp = load_trust_region_problem()\n"); 557 StringAppendF(&matlab_script, 558 "lsqp.num_rows = %d;\n", A->num_rows()); 559 StringAppendF(&matlab_script, 560 "lsqp.num_cols = %d;\n", A->num_cols()); 561 562 { 563 string filename = filename_base + "_A.txt"; 564 FILE* fptr = fopen(filename.c_str(), "w"); 565 CHECK_NOTNULL(fptr); 566 A->ToTextFile(fptr); 567 fclose(fptr); 568 StringAppendF(&matlab_script, 569 "tmp = load('%s', '-ascii');\n", filename.c_str()); 570 StringAppendF( 571 &matlab_script, 572 "lsqp.A = sparse(tmp(:, 1) + 1, tmp(:, 2) + 1, tmp(:, 3), %d, %d);\n", 573 A->num_rows(), 574 A->num_cols()); 575 } 576 577 578 if (D != NULL) { 579 string filename = filename_base + "_D.txt"; 580 WriteArrayToFileOrDie(filename, D, A->num_cols()); 581 StringAppendF(&matlab_script, 582 "lsqp.D = load('%s', '-ascii');\n", filename.c_str()); 583 } 584 585 if (b != NULL) { 586 string filename = filename_base + "_b.txt"; 587 WriteArrayToFileOrDie(filename, b, A->num_rows()); 588 StringAppendF(&matlab_script, 589 "lsqp.b = load('%s', '-ascii');\n", filename.c_str()); 590 } 591 592 if (x != NULL) { 593 string filename = filename_base + "_x.txt"; 594 WriteArrayToFileOrDie(filename, x, A->num_cols()); 595 StringAppendF(&matlab_script, 596 "lsqp.x = load('%s', '-ascii');\n", filename.c_str()); 597 } 598 599 string matlab_filename = filename_base + ".m"; 600 WriteStringToFileOrDie(matlab_script, matlab_filename); 601 return true; 602} 603} // namespace 604 605bool DumpLinearLeastSquaresProblem(const string& filename_base, 606 DumpFormatType dump_format_type, 607 const SparseMatrix* A, 608 const double* D, 609 const double* b, 610 const double* x, 611 int num_eliminate_blocks) { 612 switch (dump_format_type) { 613 case CONSOLE: 614 return DumpLinearLeastSquaresProblemToConsole(A, D, b, x, 615 num_eliminate_blocks); 616 case TEXTFILE: 617 return DumpLinearLeastSquaresProblemToTextFile(filename_base, 618 A, D, b, x, 619 num_eliminate_blocks); 620 default: 621 LOG(FATAL) << "Unknown DumpFormatType " << dump_format_type; 622 }; 623 624 return true; 625} 626 627} // namespace internal 628} // namespace ceres 629