1/*
2 * Copyright (C) 2013 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.android.gallery3d.filtershow.tools;
18
19import android.util.Log;
20
21public class MatrixFit {
22    // Simple implementation of a matrix fit in N dimensions.
23
24    private static final String LOGTAG = "MatrixFit";
25
26    private double[][] mMatrix;
27    private int mDimension;
28    private boolean mValid = false;
29    private static double sEPS = 1.0f/10000000000.0f;
30
31    public MatrixFit(double[][] from, double[][] to) {
32        mValid = fit(from, to);
33    }
34
35    public int getDimension() {
36        return mDimension;
37    }
38
39    public boolean isValid() {
40        return mValid;
41    }
42
43    public double[][] getMatrix() {
44        return mMatrix;
45    }
46
47    public boolean fit(double[][] from, double[][] to) {
48        if ((from.length != to.length) || (from.length < 1)) {
49            Log.e(LOGTAG, "from and to must be of same size");
50            return false;
51        }
52
53        mDimension = from[0].length;
54        mMatrix = new double[mDimension +1][mDimension + mDimension +1];
55
56        if (from.length < mDimension) {
57            Log.e(LOGTAG, "Too few points => under-determined system");
58            return false;
59        }
60
61        double[][] q = new double[from.length][mDimension];
62        for (int i = 0; i < from.length; i++) {
63            for (int j = 0; j < mDimension; j++) {
64                q[i][j] = from[i][j];
65            }
66        }
67
68        double[][] p = new double[to.length][mDimension];
69        for (int i = 0; i < to.length; i++) {
70            for (int j = 0; j < mDimension; j++) {
71                p[i][j] = to[i][j];
72            }
73        }
74
75        // Make an empty (dim) x (dim + 1) matrix and fill it
76        double[][] c = new double[mDimension+1][mDimension];
77        for (int j = 0; j < mDimension; j++) {
78            for (int k = 0; k < mDimension + 1; k++) {
79                for (int i = 0; i < q.length; i++) {
80                    double qt = 1;
81                    if (k < mDimension) {
82                        qt = q[i][k];
83                    }
84                    c[k][j] += qt * p[i][j];
85                }
86            }
87        }
88
89        // Make an empty (dim+1) x (dim+1) matrix and fill it
90        double[][] Q = new double[mDimension+1][mDimension+1];
91        for (int qi = 0; qi < q.length; qi++) {
92            double[] qt = new double[mDimension + 1];
93            for (int i = 0; i < mDimension; i++) {
94                qt[i] = q[qi][i];
95            }
96            qt[mDimension] = 1;
97            for (int i = 0; i < mDimension + 1; i++) {
98                for (int j = 0; j < mDimension + 1; j++) {
99                    Q[i][j] += qt[i] * qt[j];
100                }
101            }
102        }
103
104        // Use a gaussian elimination to solve the linear system
105        for (int i = 0; i < mDimension + 1; i++) {
106            for (int j = 0; j < mDimension + 1; j++) {
107                mMatrix[i][j] = Q[i][j];
108            }
109            for (int j = 0; j < mDimension; j++) {
110                mMatrix[i][mDimension + 1 + j] = c[i][j];
111            }
112        }
113        if (!gaussianElimination(mMatrix)) {
114            return false;
115        }
116        return true;
117    }
118
119    public double[] apply(double[] point) {
120        if (mDimension != point.length) {
121            return null;
122        }
123        double[] res = new double[mDimension];
124        for (int j = 0; j < mDimension; j++) {
125            for (int i = 0; i < mDimension; i++) {
126                res[j] += point[i] * mMatrix[i][j+ mDimension +1];
127            }
128            res[j] += mMatrix[mDimension][j+ mDimension +1];
129        }
130        return res;
131    }
132
133    public void printEquation() {
134        for (int j = 0; j < mDimension; j++) {
135            String str = "x" + j + "' = ";
136            for (int i = 0; i < mDimension; i++) {
137                str += "x" + i + " * " + mMatrix[i][j+mDimension+1] + " + ";
138            }
139            str += mMatrix[mDimension][j+mDimension+1];
140            Log.v(LOGTAG, str);
141        }
142    }
143
144    private void printMatrix(String name, double[][] matrix) {
145        Log.v(LOGTAG, "name: " + name);
146        for (int i = 0; i < matrix.length; i++) {
147            String str = "";
148            for (int j = 0; j < matrix[0].length; j++) {
149                str += "" + matrix[i][j] + " ";
150            }
151            Log.v(LOGTAG, str);
152        }
153    }
154
155    /*
156     * Transforms the given matrix into a row echelon matrix
157     */
158    private boolean gaussianElimination(double[][] m) {
159        int h = m.length;
160        int w = m[0].length;
161
162        for (int y = 0; y < h; y++) {
163            int maxrow = y;
164            for (int y2 = y + 1; y2 < h; y2++) { // Find max pivot
165                if (Math.abs(m[y2][y]) > Math.abs(m[maxrow][y])) {
166                    maxrow = y2;
167                }
168            }
169            // swap
170            for (int i = 0; i < mDimension; i++) {
171                double t = m[y][i];
172                m[y][i] = m[maxrow][i];
173                m[maxrow][i] = t;
174            }
175
176            if (Math.abs(m[y][y]) <= sEPS) { // Singular Matrix
177                return false;
178            }
179            for (int y2 = y + 1; y2 < h; y2++) { // Eliminate column y
180                double c = m[y2][y] / m[y][y];
181                for (int x = y; x < w; x++) {
182                    m[y2][x] -= m[y][x] * c;
183                }
184            }
185        }
186        for (int y = h -1; y > -1; y--) { // Back substitution
187            double c = m[y][y];
188            for (int y2 = 0; y2 < y; y2++) {
189                for (int x = w - 1; x > y - 1; x--) {
190                    m[y2][x] -= m[y][x] * m[y2][y] / c;
191                }
192            }
193            m[y][y] /= c;
194            for (int x = h; x < w; x++) { // Normalize row y
195                m[y][x] /= c;
196            }
197        }
198        return true;
199    }
200}
201