1/*
2 * Copyright 2013 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#pragma once
18
19#include <math.h>
20#include <stdint.h>
21#include <sys/types.h>
22
23#include <cmath>
24#include <exception>
25#include <iomanip>
26#include <stdexcept>
27
28#include <math/quat.h>
29#include <math/TVecHelpers.h>
30
31#include  <utils/String8.h>
32
// Branch-prediction hints. They are only defined when the includer has not
// already defined LIKELY; LIKELY_DEFINED_LOCAL records that case so the
// macros can be #undef'd again at the end of this header.
#ifndef LIKELY
#define LIKELY_DEFINED_LOCAL
#ifdef __cplusplus
#   define LIKELY( exp )    (__builtin_expect( !!(exp), true ))
#   define UNLIKELY( exp )  (__builtin_expect( !!(exp), false ))
#else
#   define LIKELY( exp )    (__builtin_expect( !!(exp), 1 ))
#   define UNLIKELY( exp )  (__builtin_expect( !!(exp), 0 ))
#endif
#endif

// GCC/Clang "pure" attribute: the function has no side effects and its result
// depends only on its arguments and memory it reads, so the compiler may merge
// repeated calls. Local to this header (#undef'd at the end).
#define PURE __attribute__((pure))

// Expands to constexpr only from C++14 on: functions marked CONSTEXPR below use
// loops and several statements, which C++11 constexpr functions cannot contain.
// Local to this header (#undef'd at the end).
#if __cplusplus >= 201402L
#define CONSTEXPR constexpr
#else
#define CONSTEXPR
#endif
51
52namespace android {
53namespace details {
54// -------------------------------------------------------------------------------------
55
56/*
57 * No user serviceable parts here.
58 *
59 * Don't use this file directly, instead include ui/mat*.h
60 */
61
62
63/*
64 * Matrix utilities
65 */
66
67namespace matrix {
68
// Scalar base cases: a single number is its own transpose and its own trace.
// The generic matrix transpose()/trace() templates apply these to each element,
// so recursion over matrices of matrices bottoms out here.
inline constexpr auto transpose(int x)    -> int    { return x; }
inline constexpr auto transpose(float x)  -> float  { return x; }
inline constexpr auto transpose(double x) -> double { return x; }

inline constexpr auto trace(int x)    -> int    { return x; }
inline constexpr auto trace(float x)  -> float  { return x; }
inline constexpr auto trace(double x) -> double { return x; }
76
77/*
78 * Matrix inversion
79 */
80template<typename MATRIX>
81MATRIX PURE gaussJordanInverse(const MATRIX& src) {
82    typedef typename MATRIX::value_type T;
83    static constexpr unsigned int N = MATRIX::NUM_ROWS;
84    MATRIX tmp(src);
85    MATRIX inverted(1);
86
87    for (size_t i = 0; i < N; ++i) {
88        // look for largest element in i'th column
89        size_t swap = i;
90        T t = std::abs(tmp[i][i]);
91        for (size_t j = i + 1; j < N; ++j) {
92            const T t2 = std::abs(tmp[j][i]);
93            if (t2 > t) {
94                swap = j;
95                t = t2;
96            }
97        }
98
99        if (swap != i) {
100            // swap columns.
101            std::swap(tmp[i], tmp[swap]);
102            std::swap(inverted[i], inverted[swap]);
103        }
104
105        const T denom(tmp[i][i]);
106        for (size_t k = 0; k < N; ++k) {
107            tmp[i][k] /= denom;
108            inverted[i][k] /= denom;
109        }
110
111        // Factor out the lower triangle
112        for (size_t j = 0; j < N; ++j) {
113            if (j != i) {
114                const T d = tmp[j][i];
115                for (size_t k = 0; k < N; ++k) {
116                    tmp[j][k] -= tmp[i][k] * d;
117                    inverted[j][k] -= inverted[i][k] * d;
118                }
119            }
120        }
121    }
122
123    return inverted;
124}
125
126
127//------------------------------------------------------------------------------
128// 2x2 matrix inverse is easy.
129template <typename MATRIX>
130CONSTEXPR MATRIX PURE fastInverse2(const MATRIX& x) {
131    typedef typename MATRIX::value_type T;
132
133    // Assuming the input matrix is:
134    // | a b |
135    // | c d |
136    //
137    // The analytic inverse is
138    // | d -b |
139    // | -c a | / (a d - b c)
140    //
141    // Importantly, our matrices are column-major!
142
143    MATRIX inverted(MATRIX::NO_INIT);
144
145    const T a = x[0][0];
146    const T c = x[0][1];
147    const T b = x[1][0];
148    const T d = x[1][1];
149
150    const T det((a * d) - (b * c));
151    inverted[0][0] =  d / det;
152    inverted[0][1] = -c / det;
153    inverted[1][0] = -b / det;
154    inverted[1][1] =  a / det;
155    return inverted;
156}
157
158
159//------------------------------------------------------------------------------
160// From the Wikipedia article on matrix inversion's section on fast 3x3
161// matrix inversion:
162// http://en.wikipedia.org/wiki/Invertible_matrix#Inversion_of_3.C3.973_matrices
163template <typename MATRIX>
164CONSTEXPR MATRIX PURE fastInverse3(const MATRIX& x) {
165    typedef typename MATRIX::value_type T;
166
167    // Assuming the input matrix is:
168    // | a b c |
169    // | d e f |
170    // | g h i |
171    //
172    // The analytic inverse is
173    // | A B C |^T
174    // | D E F |
175    // | G H I | / determinant
176    //
177    // Which is
178    // | A D G |
179    // | B E H |
180    // | C F I | / determinant
181    //
182    // Where:
183    // A = (ei - fh), B = (fg - di), C = (dh - eg)
184    // D = (ch - bi), E = (ai - cg), F = (bg - ah)
185    // G = (bf - ce), H = (cd - af), I = (ae - bd)
186    //
187    // and the determinant is a*A + b*B + c*C (The rule of Sarrus)
188    //
189    // Importantly, our matrices are column-major!
190
191    MATRIX inverted(MATRIX::NO_INIT);
192
193    const T a = x[0][0];
194    const T b = x[1][0];
195    const T c = x[2][0];
196    const T d = x[0][1];
197    const T e = x[1][1];
198    const T f = x[2][1];
199    const T g = x[0][2];
200    const T h = x[1][2];
201    const T i = x[2][2];
202
203    // Do the full analytic inverse
204    const T A = e * i - f * h;
205    const T B = f * g - d * i;
206    const T C = d * h - e * g;
207    inverted[0][0] = A;                 // A
208    inverted[0][1] = B;                 // B
209    inverted[0][2] = C;                 // C
210    inverted[1][0] = c * h - b * i;     // D
211    inverted[1][1] = a * i - c * g;     // E
212    inverted[1][2] = b * g - a * h;     // F
213    inverted[2][0] = b * f - c * e;     // G
214    inverted[2][1] = c * d - a * f;     // H
215    inverted[2][2] = a * e - b * d;     // I
216
217    const T det(a * A + b * B + c * C);
218    for (size_t col = 0; col < 3; ++col) {
219        for (size_t row = 0; row < 3; ++row) {
220            inverted[col][row] /= det;
221        }
222    }
223
224    return inverted;
225}
226
227/**
228 * Inversion function which switches on the matrix size.
229 * @warning This function assumes the matrix is invertible. The result is
230 * undefined if it is not. It is the responsibility of the caller to
231 * make sure the matrix is not singular.
232 */
233template <typename MATRIX>
234inline constexpr MATRIX PURE inverse(const MATRIX& matrix) {
235    static_assert(MATRIX::NUM_ROWS == MATRIX::NUM_COLS, "only square matrices can be inverted");
236    return (MATRIX::NUM_ROWS == 2) ? fastInverse2<MATRIX>(matrix) :
237          ((MATRIX::NUM_ROWS == 3) ? fastInverse3<MATRIX>(matrix) :
238                    gaussJordanInverse<MATRIX>(matrix));
239}
240
241template<typename MATRIX_R, typename MATRIX_A, typename MATRIX_B>
242CONSTEXPR MATRIX_R PURE multiply(const MATRIX_A& lhs, const MATRIX_B& rhs) {
243    // pre-requisite:
244    //  lhs : D columns, R rows
245    //  rhs : C columns, D rows
246    //  res : C columns, R rows
247
248    static_assert(MATRIX_A::NUM_COLS == MATRIX_B::NUM_ROWS,
249            "matrices can't be multiplied. invalid dimensions.");
250    static_assert(MATRIX_R::NUM_COLS == MATRIX_B::NUM_COLS,
251            "invalid dimension of matrix multiply result.");
252    static_assert(MATRIX_R::NUM_ROWS == MATRIX_A::NUM_ROWS,
253            "invalid dimension of matrix multiply result.");
254
255    MATRIX_R res(MATRIX_R::NO_INIT);
256    for (size_t col = 0; col < MATRIX_R::NUM_COLS; ++col) {
257        res[col] = lhs * rhs[col];
258    }
259    return res;
260}
261
262// transpose. this handles matrices of matrices
263template <typename MATRIX>
264CONSTEXPR MATRIX PURE transpose(const MATRIX& m) {
265    // for now we only handle square matrix transpose
266    static_assert(MATRIX::NUM_COLS == MATRIX::NUM_ROWS, "transpose only supports square matrices");
267    MATRIX result(MATRIX::NO_INIT);
268    for (size_t col = 0; col < MATRIX::NUM_COLS; ++col) {
269        for (size_t row = 0; row < MATRIX::NUM_ROWS; ++row) {
270            result[col][row] = transpose(m[row][col]);
271        }
272    }
273    return result;
274}
275
276// trace. this handles matrices of matrices
277template <typename MATRIX>
278CONSTEXPR typename MATRIX::value_type PURE trace(const MATRIX& m) {
279    static_assert(MATRIX::NUM_COLS == MATRIX::NUM_ROWS, "trace only defined for square matrices");
280    typename MATRIX::value_type result(0);
281    for (size_t col = 0; col < MATRIX::NUM_COLS; ++col) {
282        result += trace(m[col][col]);
283    }
284    return result;
285}
286
287// diag. this handles matrices of matrices
288template <typename MATRIX>
289CONSTEXPR typename MATRIX::col_type PURE diag(const MATRIX& m) {
290    static_assert(MATRIX::NUM_COLS == MATRIX::NUM_ROWS, "diag only defined for square matrices");
291    typename MATRIX::col_type result(MATRIX::col_type::NO_INIT);
292    for (size_t col = 0; col < MATRIX::NUM_COLS; ++col) {
293        result[col] = m[col][col];
294    }
295    return result;
296}
297
298//------------------------------------------------------------------------------
299// This is taken from the Imath MatrixAlgo code, and is identical to Eigen.
300template <typename MATRIX>
301TQuaternion<typename MATRIX::value_type> extractQuat(const MATRIX& mat) {
302    typedef typename MATRIX::value_type T;
303
304    TQuaternion<T> quat(TQuaternion<T>::NO_INIT);
305
306    // Compute the trace to see if it is positive or not.
307    const T trace = mat[0][0] + mat[1][1] + mat[2][2];
308
309    // check the sign of the trace
310    if (LIKELY(trace > 0)) {
311        // trace is positive
312        T s = std::sqrt(trace + 1);
313        quat.w = T(0.5) * s;
314        s = T(0.5) / s;
315        quat.x = (mat[1][2] - mat[2][1]) * s;
316        quat.y = (mat[2][0] - mat[0][2]) * s;
317        quat.z = (mat[0][1] - mat[1][0]) * s;
318    } else {
319        // trace is negative
320
321        // Find the index of the greatest diagonal
322        size_t i = 0;
323        if (mat[1][1] > mat[0][0]) { i = 1; }
324        if (mat[2][2] > mat[i][i]) { i = 2; }
325
326        // Get the next indices: (n+1)%3
327        static constexpr size_t next_ijk[3] = { 1, 2, 0 };
328        size_t j = next_ijk[i];
329        size_t k = next_ijk[j];
330        T s = std::sqrt((mat[i][i] - (mat[j][j] + mat[k][k])) + 1);
331        quat[i] = T(0.5) * s;
332        if (s != 0) {
333            s = T(0.5) / s;
334        }
335        quat.w  = (mat[j][k] - mat[k][j]) * s;
336        quat[j] = (mat[i][j] + mat[j][i]) * s;
337        quat[k] = (mat[i][k] + mat[k][i]) * s;
338    }
339    return quat;
340}
341
342template <typename MATRIX>
343String8 asString(const MATRIX& m) {
344    String8 s;
345    for (size_t c = 0; c < MATRIX::col_size(); c++) {
346        s.append("|  ");
347        for (size_t r = 0; r < MATRIX::row_size(); r++) {
348            s.appendFormat("%7.2f  ", m[r][c]);
349        }
350        s.append("|\n");
351    }
352    return s;
353}
354
355}  // namespace matrix
356
357// -------------------------------------------------------------------------------------
358
359/*
360 * TMatProductOperators implements basic arithmetic and basic compound assignments
361 * operators on a vector of type BASE<T>.
362 *
363 * BASE only needs to implement operator[] and size().
364 * By simply inheriting from TMatProductOperators<BASE, T> BASE will automatically
365 * get all the functionality here.
366 */
367
368template <template<typename T> class BASE, typename T>
369class TMatProductOperators {
370public:
371    // multiply by a scalar
372    BASE<T>& operator *= (T v) {
373        BASE<T>& lhs(static_cast< BASE<T>& >(*this));
374        for (size_t col = 0; col < BASE<T>::NUM_COLS; ++col) {
375            lhs[col] *= v;
376        }
377        return lhs;
378    }
379
380    //  matrix *= matrix
381    template<typename U>
382    const BASE<T>& operator *= (const BASE<U>& rhs) {
383        BASE<T>& lhs(static_cast< BASE<T>& >(*this));
384        lhs = matrix::multiply<BASE<T> >(lhs, rhs);
385        return lhs;
386    }
387
388    // divide by a scalar
389    BASE<T>& operator /= (T v) {
390        BASE<T>& lhs(static_cast< BASE<T>& >(*this));
391        for (size_t col = 0; col < BASE<T>::NUM_COLS; ++col) {
392            lhs[col] /= v;
393        }
394        return lhs;
395    }
396
397    // matrix * matrix, result is a matrix of the same type than the lhs matrix
398    template<typename U>
399    friend CONSTEXPR BASE<T> PURE operator *(const BASE<T>& lhs, const BASE<U>& rhs) {
400        return matrix::multiply<BASE<T> >(lhs, rhs);
401    }
402};
403
404/*
405 * TMatSquareFunctions implements functions on a matrix of type BASE<T>.
406 *
407 * BASE only needs to implement:
408 *  - operator[]
409 *  - col_type
410 *  - row_type
411 *  - COL_SIZE
412 *  - ROW_SIZE
413 *
414 * By simply inheriting from TMatSquareFunctions<BASE, T> BASE will automatically
415 * get all the functionality here.
416 */
417
418template<template<typename U> class BASE, typename T>
419class TMatSquareFunctions {
420public:
421
422    /*
423     * NOTE: the functions below ARE NOT member methods. They are friend functions
424     * with they definition inlined with their declaration. This makes these
425     * template functions available to the compiler when (and only when) this class
426     * is instantiated, at which point they're only templated on the 2nd parameter
427     * (the first one, BASE<T> being known).
428     */
429    friend inline CONSTEXPR BASE<T> PURE inverse(const BASE<T>& matrix) {
430        return matrix::inverse(matrix);
431    }
432    friend inline constexpr BASE<T> PURE transpose(const BASE<T>& m) {
433        return matrix::transpose(m);
434    }
435    friend inline constexpr T PURE trace(const BASE<T>& m) {
436        return matrix::trace(m);
437    }
438};
439
440template<template<typename U> class BASE, typename T>
441class TMatHelpers {
442public:
443    constexpr inline size_t getColumnSize() const   { return BASE<T>::COL_SIZE; }
444    constexpr inline size_t getRowSize() const      { return BASE<T>::ROW_SIZE; }
445    constexpr inline size_t getColumnCount() const  { return BASE<T>::NUM_COLS; }
446    constexpr inline size_t getRowCount() const     { return BASE<T>::NUM_ROWS; }
447    constexpr inline size_t size()  const           { return BASE<T>::ROW_SIZE; }  // for TVec*<>
448
449    // array access
450    constexpr T const* asArray() const {
451        return &static_cast<BASE<T> const &>(*this)[0][0];
452    }
453
454    // element access
455    inline constexpr T const& operator()(size_t row, size_t col) const {
456        return static_cast<BASE<T> const &>(*this)[col][row];
457    }
458
459    inline T& operator()(size_t row, size_t col) {
460        return static_cast<BASE<T>&>(*this)[col][row];
461    }
462
463    template <typename VEC>
464    static CONSTEXPR BASE<T> translate(const VEC& t) {
465        BASE<T> r;
466        r[BASE<T>::NUM_COLS-1] = t;
467        return r;
468    }
469
470    template <typename VEC>
471    static constexpr BASE<T> scale(const VEC& s) {
472        return BASE<T>(s);
473    }
474
475    friend inline CONSTEXPR BASE<T> PURE abs(BASE<T> m) {
476        for (size_t col = 0; col < BASE<T>::NUM_COLS; ++col) {
477            m[col] = abs(m[col]);
478        }
479        return m;
480    }
481};
482
483// functions for 3x3 and 4x4 matrices
/*
 * TMatTransform provides rotation factories (rotate, eulerYXZ, eulerZYX) and
 * quaternion extraction for 3x3 and 4x4 matrices of type BASE<T>. Elements are
 * addressed as r[col][row].
 *
 * NOTE(review): the factories start from a default-constructed BASE<T> and only
 * write (part of) the upper-left 3x3 block, so they assume the default
 * constructor yields the identity — confirm in mat3/mat4.
 */
template<template<typename U> class BASE, typename T>
class TMatTransform {
public:
    // Compile-time restriction: only 3x3 and 4x4 matrices may use this mixin.
    inline constexpr TMatTransform() {
        static_assert(BASE<T>::NUM_ROWS == 3 || BASE<T>::NUM_ROWS == 4, "3x3 or 4x4 matrices only");
    }

    /**
     * Rotation of `radian` radians about the axis `about`.
     *
     * If `about` compares exactly equal to the unit X, Y or Z axis, only the
     * 2x2 block of that rotation plane is written. Otherwise `about` is
     * normalized and the general axis-angle (Rodrigues) formula is used, after
     * which the 3x3 entries are clamped to [-1, 1] (presumably to absorb
     * rounding error). A zero-length axis is not rejected here.
     * @param radian rotation angle in radians
     * @param about rotation axis; needs x/y/z members, a value_type, and a
     *        normalize() overload
     */
    template <typename A, typename VEC>
    static CONSTEXPR BASE<T> rotate(A radian, const VEC& about) {
        BASE<T> r;
        T c = std::cos(radian);
        T s = std::sin(radian);
        if (about.x == 1 && about.y == 0 && about.z == 0) {
            r[1][1] = c;   r[2][2] = c;
            r[1][2] = s;   r[2][1] = -s;
        } else if (about.x == 0 && about.y == 1 && about.z == 0) {
            r[0][0] = c;   r[2][2] = c;
            r[2][0] = s;   r[0][2] = -s;
        } else if (about.x == 0 && about.y == 0 && about.z == 1) {
            r[0][0] = c;   r[1][1] = c;
            r[0][1] = s;   r[1][0] = -s;
        } else {
            VEC nabout = normalize(about);
            typename VEC::value_type x = nabout.x;
            typename VEC::value_type y = nabout.y;
            typename VEC::value_type z = nabout.z;
            T nc = 1 - c;
            T xy = x * y;
            T yz = y * z;
            T zx = z * x;
            T xs = x * s;
            T ys = y * s;
            T zs = z * s;
            // R = c*I + (1-c)*n*n^T + s*[n]x, written out entry by entry.
            r[0][0] = x*x*nc +  c;    r[1][0] =  xy*nc - zs;    r[2][0] =  zx*nc + ys;
            r[0][1] =  xy*nc + zs;    r[1][1] = y*y*nc +  c;    r[2][1] =  yz*nc - xs;
            r[0][2] =  zx*nc - ys;    r[1][2] =  yz*nc + xs;    r[2][2] = z*z*nc +  c;

            // Clamp results to -1, 1.
            for (size_t col = 0; col < 3; ++col) {
                for (size_t row = 0; row < 3; ++row) {
                    r[col][row] = std::min(std::max(r[col][row], T(-1)), T(1));
                }
            }
        }
        return r;
    }

    /**
     * Create a matrix from euler angles using YPR around YXZ respectively
     * @param yaw about Y axis
     * @param pitch about X axis
     * @param roll about Z axis
     *
     * NOTE(review): the implementation returns eulerZYX(roll, pitch, yaw),
     * which is Rz(roll) * Ry(pitch) * Rx(yaw) given eulerZYX below. Yaw
     * therefore ends up about X and pitch about Y, which does not match the
     * parameter descriptions above. Confirm the intended convention before
     * relying on them.
     */
    template <
        typename Y, typename P, typename R,
        typename = typename std::enable_if<std::is_arithmetic<Y>::value >::type,
        typename = typename std::enable_if<std::is_arithmetic<P>::value >::type,
        typename = typename std::enable_if<std::is_arithmetic<R>::value >::type
    >
    static CONSTEXPR BASE<T> eulerYXZ(Y yaw, P pitch, R roll) {
        return eulerZYX(roll, pitch, yaw);
    }

    /**
     * Create a matrix from euler angles using YPR around ZYX respectively
     * @param yaw about Z axis
     * @param pitch about Y axis
     * @param roll about X axis
     *
     * The result is Rz(yaw) * Ry(pitch) * Rx(roll). The euler angles are
     * applied in ZYX order, i.e. a (column) vector is first rotated about X
     * (roll), then Y (pitch) and then Z (yaw). The 3x3 entries are clamped to
     * [-1, 1] afterwards.
     */
    template <
    typename Y, typename P, typename R,
    typename = typename std::enable_if<std::is_arithmetic<Y>::value >::type,
    typename = typename std::enable_if<std::is_arithmetic<P>::value >::type,
    typename = typename std::enable_if<std::is_arithmetic<R>::value >::type
    >
    static CONSTEXPR BASE<T> eulerZYX(Y yaw, P pitch, R roll) {
        BASE<T> r;
        T cy = std::cos(yaw);
        T sy = std::sin(yaw);
        T cp = std::cos(pitch);
        T sp = std::sin(pitch);
        T cr = std::cos(roll);
        T sr = std::sin(roll);
        T cc = cr * cy;
        T cs = cr * sy;
        T sc = sr * cy;
        T ss = sr * sy;
        r[0][0] = cp * cy;
        r[0][1] = cp * sy;
        r[0][2] = -sp;
        r[1][0] = sp * sc - cs;
        r[1][1] = sp * ss + cc;
        r[1][2] = cp * sr;
        r[2][0] = sp * cc + ss;
        r[2][1] = sp * cs - sc;
        r[2][2] = cp * cr;

        // Clamp results to -1, 1.
        for (size_t col = 0; col < 3; ++col) {
            for (size_t row = 0; row < 3; ++row) {
                r[col][row] = std::min(std::max(r[col][row], T(-1)), T(1));
            }
        }
        return r;
    }

    // Converts the rotation part to a quaternion via matrix::extractQuat,
    // which reads only the upper-left 3x3 elements.
    TQuaternion<T> toQuaternion() const {
        return matrix::extractQuat(static_cast<const BASE<T>&>(*this));
    }
};
597
598
599template <template<typename T> class BASE, typename T>
600class TMatDebug {
601public:
602    friend std::ostream& operator<<(std::ostream& stream, const BASE<T>& m) {
603        for (size_t row = 0; row < BASE<T>::NUM_ROWS; ++row) {
604            if (row != 0) {
605                stream << std::endl;
606            }
607            if (row == 0) {
608                stream << "/ ";
609            } else if (row == BASE<T>::NUM_ROWS-1) {
610                stream << "\\ ";
611            } else {
612                stream << "| ";
613            }
614            for (size_t col = 0; col < BASE<T>::NUM_COLS; ++col) {
615                stream << std::setw(10) << std::to_string(m[col][row]);
616            }
617            if (row == 0) {
618                stream << " \\";
619            } else if (row == BASE<T>::NUM_ROWS-1) {
620                stream << " /";
621            } else {
622                stream << " |";
623            }
624        }
625        return stream;
626    }
627
628    String8 asString() const {
629        return matrix::asString(static_cast<const BASE<T>&>(*this));
630    }
631};
632
633// -------------------------------------------------------------------------------------
634}  // namespace details
635}  // namespace android
636
637#ifdef LIKELY_DEFINED_LOCAL
638#undef LIKELY_DEFINED_LOCAL
639#undef LIKELY
640#undef UNLIKELY
641#endif //LIKELY_DEFINED_LOCAL
642
643#undef PURE
644#undef CONSTEXPR
645