1// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
3// http://code.google.com/p/ceres-solver/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9//   this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11//   this list of conditions and the following disclaimer in the documentation
12//   and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14//   used to endorse or promote products derived from this software without
15//   specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
31#include "ceres/linear_least_squares_problems.h"
32
33#include <cstdio>
34#include <string>
35#include <vector>
36#include "ceres/block_sparse_matrix.h"
37#include "ceres/block_structure.h"
38#include "ceres/casts.h"
39#include "ceres/file.h"
40#include "ceres/internal/scoped_ptr.h"
41#include "ceres/stringprintf.h"
42#include "ceres/triplet_sparse_matrix.h"
43#include "ceres/types.h"
44#include "glog/logging.h"
45
46namespace ceres {
47namespace internal {
48
49LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromId(int id) {
50  switch (id) {
51    case 0:
52      return LinearLeastSquaresProblem0();
53    case 1:
54      return LinearLeastSquaresProblem1();
55    case 2:
56      return LinearLeastSquaresProblem2();
57    case 3:
58      return LinearLeastSquaresProblem3();
59    default:
60      LOG(FATAL) << "Unknown problem id requested " << id;
61  }
62  return NULL;
63}
64
65/*
66A = [1   2]
67    [3   4]
68    [6 -10]
69
70b = [  8
71      18
72     -18]
73
74x = [2
75     3]
76
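D = [1
     2]

x_D = [1.78448275;
       2.82327586;]

Here x solves A x = b exactly, and x_D is the solution of the problem with
the diagonal D appended to A, i.e. the minimizer of
|A x - b|^2 + |diag(D) x|^2, which satisfies (A'A + diag(D)^2) x_D = A'b.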
82 */
83LinearLeastSquaresProblem* LinearLeastSquaresProblem0() {
84  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
85
86  TripletSparseMatrix* A = new TripletSparseMatrix(3, 2, 6);
87  problem->b.reset(new double[3]);
88  problem->D.reset(new double[2]);
89
90  problem->x.reset(new double[2]);
91  problem->x_D.reset(new double[2]);
92
93  int* Ai = A->mutable_rows();
94  int* Aj = A->mutable_cols();
95  double* Ax = A->mutable_values();
96
97  int counter = 0;
98  for (int i = 0; i < 3; ++i) {
99    for (int j = 0; j< 2; ++j) {
100      Ai[counter]=i;
101      Aj[counter]=j;
102      ++counter;
103    }
104  };
105
106  Ax[0] = 1.;
107  Ax[1] = 2.;
108  Ax[2] = 3.;
109  Ax[3] = 4.;
110  Ax[4] = 6;
111  Ax[5] = -10;
112  A->set_num_nonzeros(6);
113  problem->A.reset(A);
114
115  problem->b[0] = 8;
116  problem->b[1] = 18;
117  problem->b[2] = -18;
118
119  problem->x[0] = 2.0;
120  problem->x[1] = 3.0;
121
122  problem->D[0] = 1;
123  problem->D[1] = 2;
124
125  problem->x_D[0] = 1.78448275;
126  problem->x_D[1] = 2.82327586;
127  return problem;
128}
129
130
131/*
132      A = [1 0  | 2 0 0
133           3 0  | 0 4 0
134           0 5  | 0 0 6
135           0 7  | 8 0 0
136           0 9  | 1 0 0
137           0 0  | 1 1 1]
138
139      b = [0
140           1
141           2
142           3
143           4
144           5]
145
146      c = A'* b = [ 3
147                   67
148                   33
149                    9
150                   17]
151
152      A'A = [10    0    2   12   0
153              0  155   65    0  30
154              2   65   70    1   1
155             12    0    1   17   1
156              0   30    1    1  37]
157
      Eliminating the first two columns (num_eliminate_blocks = 2) gives
      the Schur complement S and the reduced right hand side r:

      S = [ 42.3419  -1.4000  -11.5806
            -1.4000   2.6000    1.0000
           -11.5806   1.0000   31.1935]

      r = [ 4.3032
            5.4000
            4.0323]
165
166      S\r = [ 0.2102
167              2.1367
168              0.1388]
169
170      A\b = [-2.3061
171              0.3172
172              0.2102
173              2.1367
174              0.1388]
175*/
176// The following two functions create a TripletSparseMatrix and a
177// BlockSparseMatrix version of this problem.
178
179// TripletSparseMatrix version.
180LinearLeastSquaresProblem* LinearLeastSquaresProblem1() {
181  int num_rows = 6;
182  int num_cols = 5;
183
184  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
185  TripletSparseMatrix* A = new TripletSparseMatrix(num_rows,
186                                                   num_cols,
187                                                   num_rows * num_cols);
188  problem->b.reset(new double[num_rows]);
189  problem->D.reset(new double[num_cols]);
190  problem->num_eliminate_blocks = 2;
191
192  int* rows = A->mutable_rows();
193  int* cols = A->mutable_cols();
194  double* values = A->mutable_values();
195
196  int nnz = 0;
197
198  // Row 1
199  {
200    rows[nnz] = 0;
201    cols[nnz] = 0;
202    values[nnz++] = 1;
203
204    rows[nnz] = 0;
205    cols[nnz] = 2;
206    values[nnz++] = 2;
207  }
208
209  // Row 2
210  {
211    rows[nnz] = 1;
212    cols[nnz] = 0;
213    values[nnz++] = 3;
214
215    rows[nnz] = 1;
216    cols[nnz] = 3;
217    values[nnz++] = 4;
218  }
219
220  // Row 3
221  {
222    rows[nnz] = 2;
223    cols[nnz] = 1;
224    values[nnz++] = 5;
225
226    rows[nnz] = 2;
227    cols[nnz] = 4;
228    values[nnz++] = 6;
229  }
230
231  // Row 4
232  {
233    rows[nnz] = 3;
234    cols[nnz] = 1;
235    values[nnz++] = 7;
236
237    rows[nnz] = 3;
238    cols[nnz] = 2;
239    values[nnz++] = 8;
240  }
241
242  // Row 5
243  {
244    rows[nnz] = 4;
245    cols[nnz] = 1;
246    values[nnz++] = 9;
247
248    rows[nnz] = 4;
249    cols[nnz] = 2;
250    values[nnz++] = 1;
251  }
252
253  // Row 6
254  {
255    rows[nnz] = 5;
256    cols[nnz] = 2;
257    values[nnz++] = 1;
258
259    rows[nnz] = 5;
260    cols[nnz] = 3;
261    values[nnz++] = 1;
262
263    rows[nnz] = 5;
264    cols[nnz] = 4;
265    values[nnz++] = 1;
266  }
267
268  A->set_num_nonzeros(nnz);
269  CHECK(A->IsValid());
270
271  problem->A.reset(A);
272
273  for (int i = 0; i < num_cols; ++i) {
274    problem->D.get()[i] = 1;
275  }
276
277  for (int i = 0; i < num_rows; ++i) {
278    problem->b.get()[i] = i;
279  }
280
281  return problem;
282}
283
284// BlockSparseMatrix version
285LinearLeastSquaresProblem* LinearLeastSquaresProblem2() {
286  int num_rows = 6;
287  int num_cols = 5;
288
289  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
290
291  problem->b.reset(new double[num_rows]);
292  problem->D.reset(new double[num_cols]);
293  problem->num_eliminate_blocks = 2;
294
295  CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
296  scoped_array<double> values(new double[num_rows * num_cols]);
297
298  for (int c = 0; c < num_cols; ++c) {
299    bs->cols.push_back(Block());
300    bs->cols.back().size = 1;
301    bs->cols.back().position = c;
302  }
303
304  int nnz = 0;
305
306  // Row 1
307  {
308    values[nnz++] = 1;
309    values[nnz++] = 2;
310
311    bs->rows.push_back(CompressedRow());
312    CompressedRow& row = bs->rows.back();
313    row.block.size = 1;
314    row.block.position = 0;
315    row.cells.push_back(Cell(0, 0));
316    row.cells.push_back(Cell(2, 1));
317  }
318
319  // Row 2
320  {
321    values[nnz++] = 3;
322    values[nnz++] = 4;
323
324    bs->rows.push_back(CompressedRow());
325    CompressedRow& row = bs->rows.back();
326    row.block.size = 1;
327    row.block.position = 1;
328    row.cells.push_back(Cell(0, 2));
329    row.cells.push_back(Cell(3, 3));
330  }
331
332  // Row 3
333  {
334    values[nnz++] = 5;
335    values[nnz++] = 6;
336
337    bs->rows.push_back(CompressedRow());
338    CompressedRow& row = bs->rows.back();
339    row.block.size = 1;
340    row.block.position = 2;
341    row.cells.push_back(Cell(1, 4));
342    row.cells.push_back(Cell(4, 5));
343  }
344
345  // Row 4
346  {
347    values[nnz++] = 7;
348    values[nnz++] = 8;
349
350    bs->rows.push_back(CompressedRow());
351    CompressedRow& row = bs->rows.back();
352    row.block.size = 1;
353    row.block.position = 3;
354    row.cells.push_back(Cell(1, 6));
355    row.cells.push_back(Cell(2, 7));
356  }
357
358  // Row 5
359  {
360    values[nnz++] = 9;
361    values[nnz++] = 1;
362
363    bs->rows.push_back(CompressedRow());
364    CompressedRow& row = bs->rows.back();
365    row.block.size = 1;
366    row.block.position = 4;
367    row.cells.push_back(Cell(1, 8));
368    row.cells.push_back(Cell(2, 9));
369  }
370
371  // Row 6
372  {
373    values[nnz++] = 1;
374    values[nnz++] = 1;
375    values[nnz++] = 1;
376
377    bs->rows.push_back(CompressedRow());
378    CompressedRow& row = bs->rows.back();
379    row.block.size = 1;
380    row.block.position = 5;
381    row.cells.push_back(Cell(2, 10));
382    row.cells.push_back(Cell(3, 11));
383    row.cells.push_back(Cell(4, 12));
384  }
385
386  BlockSparseMatrix* A = new BlockSparseMatrix(bs);
387  memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
388
389  for (int i = 0; i < num_cols; ++i) {
390    problem->D.get()[i] = 1;
391  }
392
393  for (int i = 0; i < num_rows; ++i) {
394    problem->b.get()[i] = i;
395  }
396
397  problem->A.reset(A);
398
399  return problem;
400}
401
402
/*
      A = [1 0
           3 0
           0 5
           0 7
           0 9]

      b = [0
           1
           2
           3
           4]

      A has five rows (num_rows = 5); each row and each column is a block
      of size one.
*/
418// BlockSparseMatrix version
419LinearLeastSquaresProblem* LinearLeastSquaresProblem3() {
420  int num_rows = 5;
421  int num_cols = 2;
422
423  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
424
425  problem->b.reset(new double[num_rows]);
426  problem->D.reset(new double[num_cols]);
427  problem->num_eliminate_blocks = 2;
428
429  CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
430  scoped_array<double> values(new double[num_rows * num_cols]);
431
432  for (int c = 0; c < num_cols; ++c) {
433    bs->cols.push_back(Block());
434    bs->cols.back().size = 1;
435    bs->cols.back().position = c;
436  }
437
438  int nnz = 0;
439
440  // Row 1
441  {
442    values[nnz++] = 1;
443    bs->rows.push_back(CompressedRow());
444    CompressedRow& row = bs->rows.back();
445    row.block.size = 1;
446    row.block.position = 0;
447    row.cells.push_back(Cell(0, 0));
448  }
449
450  // Row 2
451  {
452    values[nnz++] = 3;
453    bs->rows.push_back(CompressedRow());
454    CompressedRow& row = bs->rows.back();
455    row.block.size = 1;
456    row.block.position = 1;
457    row.cells.push_back(Cell(0, 1));
458  }
459
460  // Row 3
461  {
462    values[nnz++] = 5;
463    bs->rows.push_back(CompressedRow());
464    CompressedRow& row = bs->rows.back();
465    row.block.size = 1;
466    row.block.position = 2;
467    row.cells.push_back(Cell(1, 2));
468  }
469
470  // Row 4
471  {
472    values[nnz++] = 7;
473    bs->rows.push_back(CompressedRow());
474    CompressedRow& row = bs->rows.back();
475    row.block.size = 1;
476    row.block.position = 3;
477    row.cells.push_back(Cell(1, 3));
478  }
479
480  // Row 5
481  {
482    values[nnz++] = 9;
483    bs->rows.push_back(CompressedRow());
484    CompressedRow& row = bs->rows.back();
485    row.block.size = 1;
486    row.block.position = 4;
487    row.cells.push_back(Cell(1, 4));
488  }
489
490  BlockSparseMatrix* A = new BlockSparseMatrix(bs);
491  memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
492
493  for (int i = 0; i < num_cols; ++i) {
494    problem->D.get()[i] = 1;
495  }
496
497  for (int i = 0; i < num_rows; ++i) {
498    problem->b.get()[i] = i;
499  }
500
501  problem->A.reset(A);
502
503  return problem;
504}
505
506namespace {
507bool DumpLinearLeastSquaresProblemToConsole(const SparseMatrix* A,
508                                            const double* D,
509                                            const double* b,
510                                            const double* x,
511                                            int num_eliminate_blocks) {
512  CHECK_NOTNULL(A);
513  Matrix AA;
514  A->ToDenseMatrix(&AA);
515  LOG(INFO) << "A^T: \n" << AA.transpose();
516
517  if (D != NULL) {
518    LOG(INFO) << "A's appended diagonal:\n"
519              << ConstVectorRef(D, A->num_cols());
520  }
521
522  if (b != NULL) {
523    LOG(INFO) << "b: \n" << ConstVectorRef(b, A->num_rows());
524  }
525
526  if (x != NULL) {
527    LOG(INFO) << "x: \n" << ConstVectorRef(x, A->num_cols());
528  }
529  return true;
530};
531
532void WriteArrayToFileOrDie(const string& filename,
533                           const double* x,
534                           const int size) {
535  CHECK_NOTNULL(x);
536  VLOG(2) << "Writing array to: " << filename;
537  FILE* fptr = fopen(filename.c_str(), "w");
538  CHECK_NOTNULL(fptr);
539  for (int i = 0; i < size; ++i) {
540    fprintf(fptr, "%17f\n", x[i]);
541  }
542  fclose(fptr);
543}
544
545bool DumpLinearLeastSquaresProblemToTextFile(const string& filename_base,
546                                             const SparseMatrix* A,
547                                             const double* D,
548                                             const double* b,
549                                             const double* x,
550                                             int num_eliminate_blocks) {
551  CHECK_NOTNULL(A);
552  LOG(INFO) << "writing to: " << filename_base << "*";
553
554  string matlab_script;
555  StringAppendF(&matlab_script,
556                "function lsqp = load_trust_region_problem()\n");
557  StringAppendF(&matlab_script,
558                "lsqp.num_rows = %d;\n", A->num_rows());
559  StringAppendF(&matlab_script,
560                "lsqp.num_cols = %d;\n", A->num_cols());
561
562  {
563    string filename = filename_base + "_A.txt";
564    FILE* fptr = fopen(filename.c_str(), "w");
565    CHECK_NOTNULL(fptr);
566    A->ToTextFile(fptr);
567    fclose(fptr);
568    StringAppendF(&matlab_script,
569                  "tmp = load('%s', '-ascii');\n", filename.c_str());
570    StringAppendF(
571        &matlab_script,
572        "lsqp.A = sparse(tmp(:, 1) + 1, tmp(:, 2) + 1, tmp(:, 3), %d, %d);\n",
573        A->num_rows(),
574        A->num_cols());
575  }
576
577
578  if (D != NULL) {
579    string filename = filename_base + "_D.txt";
580    WriteArrayToFileOrDie(filename, D, A->num_cols());
581    StringAppendF(&matlab_script,
582                  "lsqp.D = load('%s', '-ascii');\n", filename.c_str());
583  }
584
585  if (b != NULL) {
586    string filename = filename_base + "_b.txt";
587    WriteArrayToFileOrDie(filename, b, A->num_rows());
588    StringAppendF(&matlab_script,
589                  "lsqp.b = load('%s', '-ascii');\n", filename.c_str());
590  }
591
592  if (x != NULL) {
593    string filename = filename_base + "_x.txt";
594    WriteArrayToFileOrDie(filename, x, A->num_cols());
595    StringAppendF(&matlab_script,
596                  "lsqp.x = load('%s', '-ascii');\n", filename.c_str());
597  }
598
599  string matlab_filename = filename_base + ".m";
600  WriteStringToFileOrDie(matlab_script, matlab_filename);
601  return true;
602}
603}  // namespace
604
605bool DumpLinearLeastSquaresProblem(const string& filename_base,
606                                   DumpFormatType dump_format_type,
607                                   const SparseMatrix* A,
608                                   const double* D,
609                                   const double* b,
610                                   const double* x,
611                                   int num_eliminate_blocks) {
612  switch (dump_format_type) {
613    case CONSOLE:
614      return DumpLinearLeastSquaresProblemToConsole(A, D, b, x,
615                                                    num_eliminate_blocks);
616    case TEXTFILE:
617      return DumpLinearLeastSquaresProblemToTextFile(filename_base,
618                                                     A, D, b, x,
619                                                     num_eliminate_blocks);
620    default:
621      LOG(FATAL) << "Unknown DumpFormatType " << dump_format_type;
622  };
623
624  return true;
625}
626
627}  // namespace internal
628}  // namespace ceres
629