Matrix.java revision 5b0fa1e4851b9f4a8fd8efe7afa89b575be727bd
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.android.server.wifi.util;
18
19/**
20 * Utility for doing basic matix calculations
21 */
22public class Matrix {
23    public final int n;
24    public final int m;
25    public final double[] mem;
26
27    /**
28     * Creates a new matrix, initialized to zeros
29     *
30     * @param rows - number of rows (n)
31     * @param cols - number of columns (m)
32     */
33    public Matrix(int rows, int cols) {
34        n = rows;
35        m = cols;
36        mem = new double[rows * cols];
37    }
38
39    /**
40     * Creates a new matrix using the provided array of values
41     * <p>
42     * Values are in row-major order.
43     *
44     * @param stride is the number of columns.
45     * @param values is the array of values.
46     * @throws IllegalArgumentException if length of values array not a multiple of stride
47     */
48    public Matrix(int stride, double[] values) {
49        n = (values.length + stride - 1) / stride;
50        m = stride;
51        mem = values;
52        if (mem.length != n * m) throw new IllegalArgumentException();
53    }
54
55    /**
56     * Creates a new matrix duplicating the given one
57     *
58     * @param that is the source Matrix.
59     */
60    public Matrix(Matrix that) {
61        n = that.n;
62        m = that.m;
63        mem = new double[that.mem.length];
64        for (int i = 0; i < mem.length; i++) {
65            mem[i] = that.mem[i];
66        }
67    }
68
69    /**
70     * Gets the matrix coefficient from row i, column j
71     *
72     * @param i row number
73     * @param j column number
74     * @return Coefficient at i,j
75     * @throws IndexOutOfBoundsException if an index is out of bounds
76     */
77    public double get(int i, int j) {
78        if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException();
79        return mem[i * m + j];
80    }
81
82    /**
83     * Store a matrix coefficient in row i, column j
84     *
85     * @param i row number
86     * @param j column number
87     * @param v Coefficient to store at i,j
88     * @throws IndexOutOfBoundsException if an index is out of bounds
89     */
90    public void put(int i, int j, double v) {
91        if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException();
92        mem[i * m + j] = v;
93    }
94
95    /**
96     * Forms the sum of two matrices, this and that
97     *
98     * @param that is the other matrix
99     * @return newly allocated matrix representing the sum of this and that
100     * @throws IllegalArgumentException if shapes differ
101     */
102    public Matrix plus(Matrix that) {
103        return plus(that, new Matrix(n, m));
104
105    }
106
107    /**
108     * Forms the sum of two matrices, this and that
109     *
110     * @param that   is the other matrix
111     * @param result is space to hold the result
112     * @return result, filled with the matrix sum
113     * @throws IllegalArgumentException if shapes differ
114     */
115    public Matrix plus(Matrix that, Matrix result) {
116        if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) {
117            throw new IllegalArgumentException();
118        }
119        for (int i = 0; i < mem.length; i++) {
120            result.mem[i] = this.mem[i] + that.mem[i];
121        }
122        return result;
123    }
124
125    /**
126     * Forms the difference of two matrices, this and that
127     *
128     * @param that is the other matrix
129     * @return newly allocated matrix representing the difference of this and that
130     * @throws IllegalArgumentException if shapes differ
131     */
132    public Matrix minus(Matrix that) {
133        return minus(that, new Matrix(n, m));
134    }
135
136    /**
137     * Forms the difference of two matrices, this and that
138     *
139     * @param that   is the other matrix
140     * @param result is space to hold the result
141     * @return result, filled with the matrix difference
142     * @throws IllegalArgumentException if shapes differ
143     */
144    public Matrix minus(Matrix that, Matrix result) {
145        if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) {
146            throw new IllegalArgumentException();
147        }
148        for (int i = 0; i < mem.length; i++) {
149            result.mem[i] = this.mem[i] - that.mem[i];
150        }
151        return result;
152    }
153
154    /**
155     * Forms the matrix product of two matrices, this and that
156     *
157     * @param that is the other matrix
158     * @return newly allocated matrix representing the matrix product of this and that
159     * @throws IllegalArgumentException if shapes are not conformant
160     */
161    public Matrix dot(Matrix that) {
162        return dot(that, new Matrix(this.n, that.m));
163    }
164
165    /**
166     * Forms the matrix product of two matrices, this and that
167     * <p>
168     * Caller supplies an object to contain the result, as well as scratch space
169     *
170     * @param that   is the other matrix
171     * @param result is space to hold the result
172     * @return result, filled with the matrix product
173     * @throws IllegalArgumentException if shapes are not conformant
174     */
175    public Matrix dot(Matrix that, Matrix result) {
176        if (!(this.n == result.n && this.m == that.n && that.m == result.m)) {
177            throw new IllegalArgumentException();
178        }
179        for (int i = 0; i < n; i++) {
180            for (int j = 0; j < that.m; j++) {
181                double s = 0.0;
182                for (int k = 0; k < m; k++) {
183                    s += this.get(i, k) * that.get(k, j);
184                }
185                result.put(i, j, s);
186            }
187        }
188        return result;
189    }
190
191    /**
192     * Forms the matrix transpose
193     *
194     * @return newly allocated transpose matrix
195     */
196    public Matrix transpose() {
197        return transpose(new Matrix(m, n));
198    }
199
200    /**
201     * Forms the matrix transpose
202     * <p>
203     * Caller supplies an object to contain the result
204     *
205     * @param result is space to hold the result
206     * @return result, filled with the matrix transpose
207     * @throws IllegalArgumentException if result shape is wrong
208     */
209    public Matrix transpose(Matrix result) {
210        if (!(this.n == result.m && this.m == result.n)) throw new IllegalArgumentException();
211        for (int i = 0; i < n; i++) {
212            for (int j = 0; j < m; j++) {
213                result.put(j, i, get(i, j));
214            }
215        }
216        return result;
217    }
218
219    /**
220     * Forms the inverse of a square matrix
221     *
222     * @return newly allocated matrix representing the matrix inverse
223     * @throws ArithmeticException if the matrix is not invertible
224     */
225    public Matrix inverse() {
226        return inverse(new Matrix(n, m), new Matrix(n, 2 * m));
227    }
228
229    /**
230     * Forms the inverse of a square matrix
231     *
232     * @param result  is space to hold the result
233     * @param scratch is workspace of dimension n by 2*n
234     * @return result, filled with the matrix inverse
235     * @throws ArithmeticException if the matrix is not invertible
236     * @throws IllegalArgumentException if shape of scratch or result is wrong
237     */
238    public Matrix inverse(Matrix result, Matrix scratch) {
239        if (!(n == m && n == result.n && m == result.m && n == scratch.n && 2 * m == scratch.m)) {
240            throw new IllegalArgumentException();
241        }
242
243        for (int i = 0; i < n; i++) {
244            for (int j = 0; j < m; j++) {
245                scratch.put(i, j, get(i, j));
246                scratch.put(i, m + j, i == j ? 1.0 : 0.0);
247            }
248        }
249
250        for (int i = 0; i < n; i++) {
251            int ibest = i;
252            double vbest = Math.abs(scratch.get(ibest, ibest));
253            for (int ii = i + 1; ii < n; ii++) {
254                double v = Math.abs(scratch.get(ii, i));
255                if (v > vbest) {
256                    ibest = ii;
257                    vbest = v;
258                }
259            }
260            if (ibest != i) {
261                for (int j = 0; j < scratch.m; j++) {
262                    double t = scratch.get(i, j);
263                    scratch.put(i, j, scratch.get(ibest, j));
264                    scratch.put(ibest, j, t);
265                }
266            }
267            double d = scratch.get(i, i);
268            if (d == 0.0) throw new ArithmeticException("Singular matrix");
269            for (int j = 0; j < scratch.m; j++) {
270                scratch.put(i, j, scratch.get(i, j) / d);
271            }
272            for (int ii = i + 1; ii < n; ii++) {
273                d = scratch.get(ii, i);
274                for (int j = 0; j < scratch.m; j++) {
275                    scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j));
276                }
277            }
278        }
279        for (int i = n - 1; i >= 0; i--) {
280            for (int ii = 0; ii < i; ii++) {
281                double d = scratch.get(ii, i);
282                for (int j = 0; j < scratch.m; j++) {
283                    scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j));
284                }
285            }
286        }
287        for (int i = 0; i < result.n; i++) {
288            for (int j = 0; j < result.m; j++) {
289                result.put(i, j, scratch.get(i, m + j));
290            }
291        }
292        return result;
293    }
294
295    /**
296     * Tests for equality
297     */
298    @Override
299    public boolean equals(Object that) {
300        if (this == that) return true;
301        if (!(that instanceof Matrix)) return false;
302        Matrix other = (Matrix) that;
303        if (n != other.n) return false;
304        if (m != other.m) return false;
305        for (int i = 0; i < mem.length; i++) {
306            if (mem[i] != other.mem[i]) return false;
307        }
308        return true;
309    }
310
311    /**
312     * Calculates a hash code
313     */
314    @Override
315    public int hashCode() {
316        int h = n * 101 + m;
317        for (int i = 0; i < mem.length; i++) {
318            h = h * 37 + Double.hashCode(mem[i]);
319        }
320        return h;
321    }
322
323    /**
324     * Makes a string representation
325     *
326     * @return string like "[a, b; c, d]"
327     */
328    @Override
329    public String toString() {
330        StringBuilder sb = new StringBuilder(n * m * 8);
331        sb.append("[");
332        for (int i = 0; i < mem.length; i++) {
333            if (i > 0) sb.append(i % m == 0 ? "; " : ", ");
334            sb.append(mem[i]);
335        }
336        sb.append("]");
337        return sb.toString();
338    }
339
340}
341