1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *      http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.commons.math.linear;
19
20import java.util.Arrays;
21
22import org.apache.commons.math.MathRuntimeException;
23import org.apache.commons.math.exception.util.LocalizedFormats;
24import org.apache.commons.math.util.FastMath;
25
26
27/**
28 * Calculates the QR-decomposition of a matrix.
29 * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
30 * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
31 * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
32 * <p>This class compute the decomposition using Householder reflectors.</p>
33 * <p>For efficiency purposes, the decomposition in packed form is transposed.
34 * This allows inner loop to iterate inside rows, which is much more cache-efficient
35 * in Java.</p>
36 *
37 * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
38 * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
39 *
40 * @version $Revision: 990655 $ $Date: 2010-08-29 23:49:40 +0200 (dim. 29 août 2010) $
41 * @since 1.2
42 */
43public class QRDecompositionImpl implements QRDecomposition {
44
45    /**
46     * A packed TRANSPOSED representation of the QR decomposition.
47     * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
48     * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
49     * from which an explicit form of Q can be recomputed if desired.</p>
50     */
51    private double[][] qrt;
52
53    /** The diagonal elements of R. */
54    private double[] rDiag;
55
56    /** Cached value of Q. */
57    private RealMatrix cachedQ;
58
59    /** Cached value of QT. */
60    private RealMatrix cachedQT;
61
62    /** Cached value of R. */
63    private RealMatrix cachedR;
64
65    /** Cached value of H. */
66    private RealMatrix cachedH;
67
68    /**
69     * Calculates the QR-decomposition of the given matrix.
70     * @param matrix The matrix to decompose.
71     */
72    public QRDecompositionImpl(RealMatrix matrix) {
73
74        final int m = matrix.getRowDimension();
75        final int n = matrix.getColumnDimension();
76        qrt = matrix.transpose().getData();
77        rDiag = new double[FastMath.min(m, n)];
78        cachedQ  = null;
79        cachedQT = null;
80        cachedR  = null;
81        cachedH  = null;
82
83        /*
84         * The QR decomposition of a matrix A is calculated using Householder
85         * reflectors by repeating the following operations to each minor
86         * A(minor,minor) of A:
87         */
88        for (int minor = 0; minor < FastMath.min(m, n); minor++) {
89
90            final double[] qrtMinor = qrt[minor];
91
92            /*
93             * Let x be the first column of the minor, and a^2 = |x|^2.
94             * x will be in the positions qr[minor][minor] through qr[m][minor].
95             * The first column of the transformed minor will be (a,0,0,..)'
96             * The sign of a is chosen to be opposite to the sign of the first
97             * component of x. Let's find a:
98             */
99            double xNormSqr = 0;
100            for (int row = minor; row < m; row++) {
101                final double c = qrtMinor[row];
102                xNormSqr += c * c;
103            }
104            final double a = (qrtMinor[minor] > 0) ? -FastMath.sqrt(xNormSqr) : FastMath.sqrt(xNormSqr);
105            rDiag[minor] = a;
106
107            if (a != 0.0) {
108
109                /*
110                 * Calculate the normalized reflection vector v and transform
111                 * the first column. We know the norm of v beforehand: v = x-ae
112                 * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
113                 * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
114                 * Here <x, e> is now qr[minor][minor].
115                 * v = x-ae is stored in the column at qr:
116                 */
117                qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])
118
119                /*
120                 * Transform the rest of the columns of the minor:
121                 * They will be transformed by the matrix H = I-2vv'/|v|^2.
122                 * If x is a column vector of the minor, then
123                 * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
124                 * Therefore the transformation is easily calculated by
125                 * subtracting the column vector (2<x,v>/|v|^2)v from x.
126                 *
127                 * Let 2<x,v>/|v|^2 = alpha. From above we have
128                 * |v|^2 = -2a*(qr[minor][minor]), so
129                 * alpha = -<x,v>/(a*qr[minor][minor])
130                 */
131                for (int col = minor+1; col < n; col++) {
132                    final double[] qrtCol = qrt[col];
133                    double alpha = 0;
134                    for (int row = minor; row < m; row++) {
135                        alpha -= qrtCol[row] * qrtMinor[row];
136                    }
137                    alpha /= a * qrtMinor[minor];
138
139                    // Subtract the column vector alpha*v from x.
140                    for (int row = minor; row < m; row++) {
141                        qrtCol[row] -= alpha * qrtMinor[row];
142                    }
143                }
144            }
145        }
146    }
147
148    /** {@inheritDoc} */
149    public RealMatrix getR() {
150
151        if (cachedR == null) {
152
153            // R is supposed to be m x n
154            final int n = qrt.length;
155            final int m = qrt[0].length;
156            cachedR = MatrixUtils.createRealMatrix(m, n);
157
158            // copy the diagonal from rDiag and the upper triangle of qr
159            for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
160                cachedR.setEntry(row, row, rDiag[row]);
161                for (int col = row + 1; col < n; col++) {
162                    cachedR.setEntry(row, col, qrt[col][row]);
163                }
164            }
165
166        }
167
168        // return the cached matrix
169        return cachedR;
170
171    }
172
173    /** {@inheritDoc} */
174    public RealMatrix getQ() {
175        if (cachedQ == null) {
176            cachedQ = getQT().transpose();
177        }
178        return cachedQ;
179    }
180
181    /** {@inheritDoc} */
182    public RealMatrix getQT() {
183
184        if (cachedQT == null) {
185
186            // QT is supposed to be m x m
187            final int n = qrt.length;
188            final int m = qrt[0].length;
189            cachedQT = MatrixUtils.createRealMatrix(m, m);
190
191            /*
192             * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
193             * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
194             * succession to the result
195             */
196            for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
197                cachedQT.setEntry(minor, minor, 1.0);
198            }
199
200            for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
201                final double[] qrtMinor = qrt[minor];
202                cachedQT.setEntry(minor, minor, 1.0);
203                if (qrtMinor[minor] != 0.0) {
204                    for (int col = minor; col < m; col++) {
205                        double alpha = 0;
206                        for (int row = minor; row < m; row++) {
207                            alpha -= cachedQT.getEntry(col, row) * qrtMinor[row];
208                        }
209                        alpha /= rDiag[minor] * qrtMinor[minor];
210
211                        for (int row = minor; row < m; row++) {
212                            cachedQT.addToEntry(col, row, -alpha * qrtMinor[row]);
213                        }
214                    }
215                }
216            }
217
218        }
219
220        // return the cached matrix
221        return cachedQT;
222
223    }
224
225    /** {@inheritDoc} */
226    public RealMatrix getH() {
227
228        if (cachedH == null) {
229
230            final int n = qrt.length;
231            final int m = qrt[0].length;
232            cachedH = MatrixUtils.createRealMatrix(m, n);
233            for (int i = 0; i < m; ++i) {
234                for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
235                    cachedH.setEntry(i, j, qrt[j][i] / -rDiag[j]);
236                }
237            }
238
239        }
240
241        // return the cached matrix
242        return cachedH;
243
244    }
245
246    /** {@inheritDoc} */
247    public DecompositionSolver getSolver() {
248        return new Solver(qrt, rDiag);
249    }
250
251    /** Specialized solver. */
252    private static class Solver implements DecompositionSolver {
253
254        /**
255         * A packed TRANSPOSED representation of the QR decomposition.
256         * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
257         * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
258         * from which an explicit form of Q can be recomputed if desired.</p>
259         */
260        private final double[][] qrt;
261
262        /** The diagonal elements of R. */
263        private final double[] rDiag;
264
265        /**
266         * Build a solver from decomposed matrix.
267         * @param qrt packed TRANSPOSED representation of the QR decomposition
268         * @param rDiag diagonal elements of R
269         */
270        private Solver(final double[][] qrt, final double[] rDiag) {
271            this.qrt   = qrt;
272            this.rDiag = rDiag;
273        }
274
275        /** {@inheritDoc} */
276        public boolean isNonSingular() {
277
278            for (double diag : rDiag) {
279                if (diag == 0) {
280                    return false;
281                }
282            }
283            return true;
284
285        }
286
287        /** {@inheritDoc} */
288        public double[] solve(double[] b)
289        throws IllegalArgumentException, InvalidMatrixException {
290
291            final int n = qrt.length;
292            final int m = qrt[0].length;
293            if (b.length != m) {
294                throw MathRuntimeException.createIllegalArgumentException(
295                        LocalizedFormats.VECTOR_LENGTH_MISMATCH,
296                        b.length, m);
297            }
298            if (!isNonSingular()) {
299                throw new SingularMatrixException();
300            }
301
302            final double[] x = new double[n];
303            final double[] y = b.clone();
304
305            // apply Householder transforms to solve Q.y = b
306            for (int minor = 0; minor < FastMath.min(m, n); minor++) {
307
308                final double[] qrtMinor = qrt[minor];
309                double dotProduct = 0;
310                for (int row = minor; row < m; row++) {
311                    dotProduct += y[row] * qrtMinor[row];
312                }
313                dotProduct /= rDiag[minor] * qrtMinor[minor];
314
315                for (int row = minor; row < m; row++) {
316                    y[row] += dotProduct * qrtMinor[row];
317                }
318
319            }
320
321            // solve triangular system R.x = y
322            for (int row = rDiag.length - 1; row >= 0; --row) {
323                y[row] /= rDiag[row];
324                final double yRow   = y[row];
325                final double[] qrtRow = qrt[row];
326                x[row] = yRow;
327                for (int i = 0; i < row; i++) {
328                    y[i] -= yRow * qrtRow[i];
329                }
330            }
331
332            return x;
333
334        }
335
336        /** {@inheritDoc} */
337        public RealVector solve(RealVector b)
338        throws IllegalArgumentException, InvalidMatrixException {
339            try {
340                return solve((ArrayRealVector) b);
341            } catch (ClassCastException cce) {
342                return new ArrayRealVector(solve(b.getData()), false);
343            }
344        }
345
346        /** Solve the linear equation A &times; X = B.
347         * <p>The A matrix is implicit here. It is </p>
348         * @param b right-hand side of the equation A &times; X = B
349         * @return a vector X that minimizes the two norm of A &times; X - B
350         * @throws IllegalArgumentException if matrices dimensions don't match
351         * @throws InvalidMatrixException if decomposed matrix is singular
352         */
353        public ArrayRealVector solve(ArrayRealVector b)
354        throws IllegalArgumentException, InvalidMatrixException {
355            return new ArrayRealVector(solve(b.getDataRef()), false);
356        }
357
358        /** {@inheritDoc} */
359        public RealMatrix solve(RealMatrix b)
360        throws IllegalArgumentException, InvalidMatrixException {
361
362            final int n = qrt.length;
363            final int m = qrt[0].length;
364            if (b.getRowDimension() != m) {
365                throw MathRuntimeException.createIllegalArgumentException(
366                        LocalizedFormats.DIMENSIONS_MISMATCH_2x2,
367                        b.getRowDimension(), b.getColumnDimension(), m, "n");
368            }
369            if (!isNonSingular()) {
370                throw new SingularMatrixException();
371            }
372
373            final int columns        = b.getColumnDimension();
374            final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
375            final int cBlocks        = (columns + blockSize - 1) / blockSize;
376            final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
377            final double[][] y       = new double[b.getRowDimension()][blockSize];
378            final double[]   alpha   = new double[blockSize];
379
380            for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
381                final int kStart = kBlock * blockSize;
382                final int kEnd   = FastMath.min(kStart + blockSize, columns);
383                final int kWidth = kEnd - kStart;
384
385                // get the right hand side vector
386                b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
387
388                // apply Householder transforms to solve Q.y = b
389                for (int minor = 0; minor < FastMath.min(m, n); minor++) {
390                    final double[] qrtMinor = qrt[minor];
391                    final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]);
392
393                    Arrays.fill(alpha, 0, kWidth, 0.0);
394                    for (int row = minor; row < m; ++row) {
395                        final double   d    = qrtMinor[row];
396                        final double[] yRow = y[row];
397                        for (int k = 0; k < kWidth; ++k) {
398                            alpha[k] += d * yRow[k];
399                        }
400                    }
401                    for (int k = 0; k < kWidth; ++k) {
402                        alpha[k] *= factor;
403                    }
404
405                    for (int row = minor; row < m; ++row) {
406                        final double   d    = qrtMinor[row];
407                        final double[] yRow = y[row];
408                        for (int k = 0; k < kWidth; ++k) {
409                            yRow[k] += alpha[k] * d;
410                        }
411                    }
412
413                }
414
415                // solve triangular system R.x = y
416                for (int j = rDiag.length - 1; j >= 0; --j) {
417                    final int      jBlock = j / blockSize;
418                    final int      jStart = jBlock * blockSize;
419                    final double   factor = 1.0 / rDiag[j];
420                    final double[] yJ     = y[j];
421                    final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
422                    int index = (j - jStart) * kWidth;
423                    for (int k = 0; k < kWidth; ++k) {
424                        yJ[k]          *= factor;
425                        xBlock[index++] = yJ[k];
426                    }
427
428                    final double[] qrtJ = qrt[j];
429                    for (int i = 0; i < j; ++i) {
430                        final double rIJ  = qrtJ[i];
431                        final double[] yI = y[i];
432                        for (int k = 0; k < kWidth; ++k) {
433                            yI[k] -= yJ[k] * rIJ;
434                        }
435                    }
436
437                }
438
439            }
440
441            return new BlockRealMatrix(n, columns, xBlocks, false);
442
443        }
444
445        /** {@inheritDoc} */
446        public RealMatrix getInverse()
447        throws InvalidMatrixException {
448            return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
449        }
450
451    }
452
453}
454