/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.android.server.wifi.util; /** * Utility for doing basic matix calculations */ public class Matrix { public final int n; public final int m; public final double[] mem; /** * Creates a new matrix, initialized to zeros * * @param rows - number of rows (n) * @param cols - number of columns (m) */ public Matrix(int rows, int cols) { n = rows; m = cols; mem = new double[rows * cols]; } /** * Creates a new matrix using the provided array of values *

* Values are in row-major order. * * @param stride is the number of columns. * @param values is the array of values. * @throws IllegalArgumentException if length of values array not a multiple of stride */ public Matrix(int stride, double[] values) { n = (values.length + stride - 1) / stride; m = stride; mem = values; if (mem.length != n * m) throw new IllegalArgumentException(); } /** * Creates a new matrix duplicating the given one * * @param that is the source Matrix. */ public Matrix(Matrix that) { n = that.n; m = that.m; mem = new double[that.mem.length]; for (int i = 0; i < mem.length; i++) { mem[i] = that.mem[i]; } } /** * Gets the matrix coefficient from row i, column j * * @param i row number * @param j column number * @return Coefficient at i,j * @throws IndexOutOfBoundsException if an index is out of bounds */ public double get(int i, int j) { if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException(); return mem[i * m + j]; } /** * Store a matrix coefficient in row i, column j * * @param i row number * @param j column number * @param v Coefficient to store at i,j * @throws IndexOutOfBoundsException if an index is out of bounds */ public void put(int i, int j, double v) { if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException(); mem[i * m + j] = v; } /** * Forms the sum of two matrices, this and that * * @param that is the other matrix * @return newly allocated matrix representing the sum of this and that * @throws IllegalArgumentException if shapes differ */ public Matrix plus(Matrix that) { return plus(that, new Matrix(n, m)); } /** * Forms the sum of two matrices, this and that * * @param that is the other matrix * @param result is space to hold the result * @return result, filled with the matrix sum * @throws IllegalArgumentException if shapes differ */ public Matrix plus(Matrix that, Matrix result) { if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) { throw new IllegalArgumentException(); } for (int i = 0; i < mem.length; i++) { result.mem[i] = this.mem[i] + that.mem[i]; } return result; } /** * Forms the difference of two matrices, this and that * * @param that is the other matrix * @return newly allocated matrix representing the difference of this and that * @throws IllegalArgumentException if shapes differ */ public Matrix minus(Matrix that) { return minus(that, new Matrix(n, m)); } /** * Forms the difference of two matrices, this and that * * @param that is the other matrix * @param result is space to hold the result * @return result, filled with the matrix difference * @throws IllegalArgumentException if shapes differ */ public Matrix minus(Matrix that, Matrix result) { if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) { throw new IllegalArgumentException(); } for (int i = 0; i < mem.length; i++) { result.mem[i] = this.mem[i] - that.mem[i]; } return result; } /** * Forms a scalar product * * @param scalar is the value to multiply by * @return newly allocated matrix representing the product this and scalar */ public Matrix times(double scalar) { return times(scalar, new Matrix(n, m)); } /** * Forms a scalar product * * @param scalar is the value to multiply by * @param result is space to hold the result * @return result, filled with the matrix difference * @throws IllegalArgumentException if shapes differ */ public Matrix times(double scalar, Matrix result) { if (!(this.n == result.n && this.m == result.m)) { throw new IllegalArgumentException(); } for (int i = 0; i < mem.length; i++) { result.mem[i] = this.mem[i] * scalar; } return result; } /** * Forms the matrix product of two matrices, this and that * * @param that is the other matrix * @return newly allocated matrix representing the matrix product of this and that * @throws IllegalArgumentException if shapes are not conformant */ public Matrix dot(Matrix that) { return dot(that, new Matrix(this.n, that.m)); } /** * Forms the matrix product of two matrices, this and that *

* Caller supplies an object to contain the result, as well as scratch space * * @param that is the other matrix * @param result is space to hold the result * @return result, filled with the matrix product * @throws IllegalArgumentException if shapes are not conformant */ public Matrix dot(Matrix that, Matrix result) { if (!(this.n == result.n && this.m == that.n && that.m == result.m)) { throw new IllegalArgumentException(); } for (int i = 0; i < n; i++) { for (int j = 0; j < that.m; j++) { double s = 0.0; for (int k = 0; k < m; k++) { s += this.get(i, k) * that.get(k, j); } result.put(i, j, s); } } return result; } /** * Forms the matrix transpose * * @return newly allocated transpose matrix */ public Matrix transpose() { return transpose(new Matrix(m, n)); } /** * Forms the matrix transpose *

* Caller supplies an object to contain the result * * @param result is space to hold the result * @return result, filled with the matrix transpose * @throws IllegalArgumentException if result shape is wrong */ public Matrix transpose(Matrix result) { if (!(this.n == result.m && this.m == result.n)) throw new IllegalArgumentException(); for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { result.put(j, i, get(i, j)); } } return result; } /** * Forms the inverse of a square matrix * * @return newly allocated matrix representing the matrix inverse * @throws ArithmeticException if the matrix is not invertible */ public Matrix inverse() { return inverse(new Matrix(n, m), new Matrix(n, 2 * m)); } /** * Forms the inverse of a square matrix * * @param result is space to hold the result * @param scratch is workspace of dimension n by 2*n * @return result, filled with the matrix inverse * @throws ArithmeticException if the matrix is not invertible * @throws IllegalArgumentException if shape of scratch or result is wrong */ public Matrix inverse(Matrix result, Matrix scratch) { if (!(n == m && n == result.n && m == result.m && n == scratch.n && 2 * m == scratch.m)) { throw new IllegalArgumentException(); } for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { scratch.put(i, j, get(i, j)); scratch.put(i, m + j, i == j ? 1.0 : 0.0); } } for (int i = 0; i < n; i++) { int ibest = i; double vbest = Math.abs(scratch.get(ibest, ibest)); for (int ii = i + 1; ii < n; ii++) { double v = Math.abs(scratch.get(ii, i)); if (v > vbest) { ibest = ii; vbest = v; } } if (ibest != i) { for (int j = 0; j < scratch.m; j++) { double t = scratch.get(i, j); scratch.put(i, j, scratch.get(ibest, j)); scratch.put(ibest, j, t); } } double d = scratch.get(i, i); if (d == 0.0) throw new ArithmeticException("Singular matrix"); for (int j = 0; j < scratch.m; j++) { scratch.put(i, j, scratch.get(i, j) / d); } for (int ii = i + 1; ii < n; ii++) { d = scratch.get(ii, i); for (int j = 0; j < scratch.m; j++) { scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j)); } } } for (int i = n - 1; i >= 0; i--) { for (int ii = 0; ii < i; ii++) { double d = scratch.get(ii, i); for (int j = 0; j < scratch.m; j++) { scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j)); } } } for (int i = 0; i < result.n; i++) { for (int j = 0; j < result.m; j++) { result.put(i, j, scratch.get(i, m + j)); } } return result; } /** * Forms the matrix product with the transpose of a second matrix * * @param that is the other matrix * @return newly allocated matrix representing the matrix product of this and that.transpose() * @throws IllegalArgumentException if shapes are not conformant */ public Matrix dotTranspose(Matrix that) { return dotTranspose(that, new Matrix(this.n, that.n)); } /** * Forms the matrix product with the transpose of a second matrix *

* Caller supplies an object to contain the result, as well as scratch space * * @param that is the other matrix * @param result is space to hold the result * @return result, filled with the matrix product of this and that.transpose() * @throws IllegalArgumentException if shapes are not conformant */ public Matrix dotTranspose(Matrix that, Matrix result) { if (!(this.n == result.n && this.m == that.m && that.n == result.m)) { throw new IllegalArgumentException(); } for (int i = 0; i < n; i++) { for (int j = 0; j < that.n; j++) { double s = 0.0; for (int k = 0; k < m; k++) { s += this.get(i, k) * that.get(j, k); } result.put(i, j, s); } } return result; } /** * Tests for equality */ @Override public boolean equals(Object that) { if (this == that) return true; if (!(that instanceof Matrix)) return false; Matrix other = (Matrix) that; if (n != other.n) return false; if (m != other.m) return false; for (int i = 0; i < mem.length; i++) { if (mem[i] != other.mem[i]) return false; } return true; } /** * Calculates a hash code */ @Override public int hashCode() { int h = n * 101 + m; for (int i = 0; i < mem.length; i++) { h = h * 37 + Double.hashCode(mem[i]); } return h; } /** * Makes a string representation * * @return string like "[a, b; c, d]" */ @Override public String toString() { StringBuilder sb = new StringBuilder(n * m * 8); sb.append("["); for (int i = 0; i < mem.length; i++) { if (i > 0) sb.append(i % m == 0 ? "; " : ", "); sb.append(mem[i]); } sb.append("]"); return sb.toString(); } }