1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *      http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.commons.math.linear;
19
20import java.lang.reflect.Array;
21
22import org.apache.commons.math.Field;
23import org.apache.commons.math.FieldElement;
24import org.apache.commons.math.MathRuntimeException;
25import org.apache.commons.math.exception.util.LocalizedFormats;
26
27/**
28 * Calculates the LUP-decomposition of a square matrix.
29 * <p>The LUP-decomposition of a matrix A consists of three matrices
30 * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
31 * upper triangular and P is a permutation matrix. All matrices are
32 * m&times;m.</p>
33 * <p>Since {@link FieldElement field elements} do not provide an ordering
34 * operator, the permutation matrix is computed here only in order to avoid
35 * a zero pivot element, no attempt is done to get the largest pivot element.</p>
36 *
37 * @param <T> the type of the field elements
38 * @version $Revision: 983921 $ $Date: 2010-08-10 12:46:06 +0200 (mar. 10 août 2010) $
39 * @since 2.0
40 */
41public class FieldLUDecompositionImpl<T extends FieldElement<T>> implements FieldLUDecomposition<T> {
42
43    /** Field to which the elements belong. */
44    private final Field<T> field;
45
46    /** Entries of LU decomposition. */
47    private T lu[][];
48
49    /** Pivot permutation associated with LU decomposition */
50    private int[] pivot;
51
52    /** Parity of the permutation associated with the LU decomposition */
53    private boolean even;
54
55    /** Singularity indicator. */
56    private boolean singular;
57
58    /** Cached value of L. */
59    private FieldMatrix<T> cachedL;
60
61    /** Cached value of U. */
62    private FieldMatrix<T> cachedU;
63
64    /** Cached value of P. */
65    private FieldMatrix<T> cachedP;
66
67    /**
68     * Calculates the LU-decomposition of the given matrix.
69     * @param matrix The matrix to decompose.
70     * @exception NonSquareMatrixException if matrix is not square
71     */
72    public FieldLUDecompositionImpl(FieldMatrix<T> matrix)
73        throws NonSquareMatrixException {
74
75        if (!matrix.isSquare()) {
76            throw new NonSquareMatrixException(matrix.getRowDimension(), matrix.getColumnDimension());
77        }
78
79        final int m = matrix.getColumnDimension();
80        field = matrix.getField();
81        lu = matrix.getData();
82        pivot = new int[m];
83        cachedL = null;
84        cachedU = null;
85        cachedP = null;
86
87        // Initialize permutation array and parity
88        for (int row = 0; row < m; row++) {
89            pivot[row] = row;
90        }
91        even     = true;
92        singular = false;
93
94        // Loop over columns
95        for (int col = 0; col < m; col++) {
96
97            T sum = field.getZero();
98
99            // upper
100            for (int row = 0; row < col; row++) {
101                final T[] luRow = lu[row];
102                sum = luRow[col];
103                for (int i = 0; i < row; i++) {
104                    sum = sum.subtract(luRow[i].multiply(lu[i][col]));
105                }
106                luRow[col] = sum;
107            }
108
109            // lower
110            int nonZero = col; // permutation row
111            for (int row = col; row < m; row++) {
112                final T[] luRow = lu[row];
113                sum = luRow[col];
114                for (int i = 0; i < col; i++) {
115                    sum = sum.subtract(luRow[i].multiply(lu[i][col]));
116                }
117                luRow[col] = sum;
118
119                if (lu[nonZero][col].equals(field.getZero())) {
120                    // try to select a better permutation choice
121                    ++nonZero;
122                }
123            }
124
125            // Singularity check
126            if (nonZero >= m) {
127                singular = true;
128                return;
129            }
130
131            // Pivot if necessary
132            if (nonZero != col) {
133                T tmp = field.getZero();
134                for (int i = 0; i < m; i++) {
135                    tmp = lu[nonZero][i];
136                    lu[nonZero][i] = lu[col][i];
137                    lu[col][i] = tmp;
138                }
139                int temp = pivot[nonZero];
140                pivot[nonZero] = pivot[col];
141                pivot[col] = temp;
142                even = !even;
143            }
144
145            // Divide the lower elements by the "winning" diagonal elt.
146            final T luDiag = lu[col][col];
147            for (int row = col + 1; row < m; row++) {
148                final T[] luRow = lu[row];
149                luRow[col] = luRow[col].divide(luDiag);
150            }
151        }
152
153    }
154
155    /** {@inheritDoc} */
156    public FieldMatrix<T> getL() {
157        if ((cachedL == null) && !singular) {
158            final int m = pivot.length;
159            cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
160            for (int i = 0; i < m; ++i) {
161                final T[] luI = lu[i];
162                for (int j = 0; j < i; ++j) {
163                    cachedL.setEntry(i, j, luI[j]);
164                }
165                cachedL.setEntry(i, i, field.getOne());
166            }
167        }
168        return cachedL;
169    }
170
171    /** {@inheritDoc} */
172    public FieldMatrix<T> getU() {
173        if ((cachedU == null) && !singular) {
174            final int m = pivot.length;
175            cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
176            for (int i = 0; i < m; ++i) {
177                final T[] luI = lu[i];
178                for (int j = i; j < m; ++j) {
179                    cachedU.setEntry(i, j, luI[j]);
180                }
181            }
182        }
183        return cachedU;
184    }
185
186    /** {@inheritDoc} */
187    public FieldMatrix<T> getP() {
188        if ((cachedP == null) && !singular) {
189            final int m = pivot.length;
190            cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
191            for (int i = 0; i < m; ++i) {
192                cachedP.setEntry(i, pivot[i], field.getOne());
193            }
194        }
195        return cachedP;
196    }
197
198    /** {@inheritDoc} */
199    public int[] getPivot() {
200        return pivot.clone();
201    }
202
203    /** {@inheritDoc} */
204    public T getDeterminant() {
205        if (singular) {
206            return field.getZero();
207        } else {
208            final int m = pivot.length;
209            T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
210            for (int i = 0; i < m; i++) {
211                determinant = determinant.multiply(lu[i][i]);
212            }
213            return determinant;
214        }
215    }
216
217    /** {@inheritDoc} */
218    public FieldDecompositionSolver<T> getSolver() {
219        return new Solver<T>(field, lu, pivot, singular);
220    }
221
222    /** Specialized solver. */
223    private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
224
225        /** Serializable version identifier. */
226        private static final long serialVersionUID = -6353105415121373022L;
227
228        /** Field to which the elements belong. */
229        private final Field<T> field;
230
231        /** Entries of LU decomposition. */
232        private final T lu[][];
233
234        /** Pivot permutation associated with LU decomposition. */
235        private final int[] pivot;
236
237        /** Singularity indicator. */
238        private final boolean singular;
239
240        /**
241         * Build a solver from decomposed matrix.
242         * @param field field to which the matrix elements belong
243         * @param lu entries of LU decomposition
244         * @param pivot pivot permutation associated with LU decomposition
245         * @param singular singularity indicator
246         */
247        private Solver(final Field<T> field, final T[][] lu,
248                       final int[] pivot, final boolean singular) {
249            this.field    = field;
250            this.lu       = lu;
251            this.pivot    = pivot;
252            this.singular = singular;
253        }
254
255        /** {@inheritDoc} */
256        public boolean isNonSingular() {
257            return !singular;
258        }
259
260        /** {@inheritDoc} */
261        public T[] solve(T[] b)
262            throws IllegalArgumentException, InvalidMatrixException {
263
264            final int m = pivot.length;
265            if (b.length != m) {
266                throw MathRuntimeException.createIllegalArgumentException(
267                        LocalizedFormats.VECTOR_LENGTH_MISMATCH,
268                        b.length, m);
269            }
270            if (singular) {
271                throw new SingularMatrixException();
272            }
273
274            @SuppressWarnings("unchecked") // field is of type T
275            final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
276
277            // Apply permutations to b
278            for (int row = 0; row < m; row++) {
279                bp[row] = b[pivot[row]];
280            }
281
282            // Solve LY = b
283            for (int col = 0; col < m; col++) {
284                final T bpCol = bp[col];
285                for (int i = col + 1; i < m; i++) {
286                    bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
287                }
288            }
289
290            // Solve UX = Y
291            for (int col = m - 1; col >= 0; col--) {
292                bp[col] = bp[col].divide(lu[col][col]);
293                final T bpCol = bp[col];
294                for (int i = 0; i < col; i++) {
295                    bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
296                }
297            }
298
299            return bp;
300
301        }
302
303        /** {@inheritDoc} */
304        public FieldVector<T> solve(FieldVector<T> b)
305            throws IllegalArgumentException, InvalidMatrixException {
306            try {
307                return solve((ArrayFieldVector<T>) b);
308            } catch (ClassCastException cce) {
309
310                final int m = pivot.length;
311                if (b.getDimension() != m) {
312                    throw MathRuntimeException.createIllegalArgumentException(
313                            LocalizedFormats.VECTOR_LENGTH_MISMATCH,
314                            b.getDimension(), m);
315                }
316                if (singular) {
317                    throw new SingularMatrixException();
318                }
319
320                @SuppressWarnings("unchecked") // field is of type T
321                final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
322
323                // Apply permutations to b
324                for (int row = 0; row < m; row++) {
325                    bp[row] = b.getEntry(pivot[row]);
326                }
327
328                // Solve LY = b
329                for (int col = 0; col < m; col++) {
330                    final T bpCol = bp[col];
331                    for (int i = col + 1; i < m; i++) {
332                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
333                    }
334                }
335
336                // Solve UX = Y
337                for (int col = m - 1; col >= 0; col--) {
338                    bp[col] = bp[col].divide(lu[col][col]);
339                    final T bpCol = bp[col];
340                    for (int i = 0; i < col; i++) {
341                        bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
342                    }
343                }
344
345                return new ArrayFieldVector<T>(bp, false);
346
347            }
348        }
349
350        /** Solve the linear equation A &times; X = B.
351         * <p>The A matrix is implicit here. It is </p>
352         * @param b right-hand side of the equation A &times; X = B
353         * @return a vector X such that A &times; X = B
354         * @exception IllegalArgumentException if matrices dimensions don't match
355         * @exception InvalidMatrixException if decomposed matrix is singular
356         */
357        public ArrayFieldVector<T> solve(ArrayFieldVector<T> b)
358            throws IllegalArgumentException, InvalidMatrixException {
359            return new ArrayFieldVector<T>(solve(b.getDataRef()), false);
360        }
361
362        /** {@inheritDoc} */
363        public FieldMatrix<T> solve(FieldMatrix<T> b)
364            throws IllegalArgumentException, InvalidMatrixException {
365
366            final int m = pivot.length;
367            if (b.getRowDimension() != m) {
368                throw MathRuntimeException.createIllegalArgumentException(
369                        LocalizedFormats.DIMENSIONS_MISMATCH_2x2,
370                        b.getRowDimension(), b.getColumnDimension(), m, "n");
371            }
372            if (singular) {
373                throw new SingularMatrixException();
374            }
375
376            final int nColB = b.getColumnDimension();
377
378            // Apply permutations to b
379            @SuppressWarnings("unchecked") // field is of type T
380            final T[][] bp = (T[][]) Array.newInstance(field.getZero().getClass(), new int[] { m, nColB });
381            for (int row = 0; row < m; row++) {
382                final T[] bpRow = bp[row];
383                final int pRow = pivot[row];
384                for (int col = 0; col < nColB; col++) {
385                    bpRow[col] = b.getEntry(pRow, col);
386                }
387            }
388
389            // Solve LY = b
390            for (int col = 0; col < m; col++) {
391                final T[] bpCol = bp[col];
392                for (int i = col + 1; i < m; i++) {
393                    final T[] bpI = bp[i];
394                    final T luICol = lu[i][col];
395                    for (int j = 0; j < nColB; j++) {
396                        bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
397                    }
398                }
399            }
400
401            // Solve UX = Y
402            for (int col = m - 1; col >= 0; col--) {
403                final T[] bpCol = bp[col];
404                final T luDiag = lu[col][col];
405                for (int j = 0; j < nColB; j++) {
406                    bpCol[j] = bpCol[j].divide(luDiag);
407                }
408                for (int i = 0; i < col; i++) {
409                    final T[] bpI = bp[i];
410                    final T luICol = lu[i][col];
411                    for (int j = 0; j < nColB; j++) {
412                        bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
413                    }
414                }
415            }
416
417            return new Array2DRowFieldMatrix<T>(bp, false);
418
419        }
420
421        /** {@inheritDoc} */
422        public FieldMatrix<T> getInverse() throws InvalidMatrixException {
423            final int m = pivot.length;
424            final T one = field.getOne();
425            FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
426            for (int i = 0; i < m; ++i) {
427                identity.setEntry(i, i, one);
428            }
429            return solve(identity);
430        }
431
432    }
433
434}
435