1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2001 Intel Corporation
5// Copyright (C) 2010 Gael Guennebaud <gael.guennebaud@inria.fr>
6// Copyright (C) 2009 Benoit Jacob <jacob.benoit.1@gmail.com>
7//
8// This Source Code Form is subject to the terms of the Mozilla
9// Public License v. 2.0. If a copy of the MPL was not distributed
10// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11
12// The SSE code for the 4x4 float and double matrix inverse in this file
// comes from the following Intel library:
14// http://software.intel.com/en-us/articles/optimized-matrix-library-for-use-with-the-intel-pentiumr-4-processors-sse2-instructions/
15//
16// Here is the respective copyright and license statement:
17//
18//   Copyright (c) 2001 Intel Corporation.
19//
20// Permition is granted to use, copy, distribute and prepare derivative works
21// of this library for any purpose and without fee, provided, that the above
22// copyright notice and this statement appear in all copies.
23// Intel makes no representations about the suitability of this software for
24// any purpose, and specifically disclaims all warranties.
25// See LEGAL.TXT for all the legal information.
26
27#ifndef EIGEN_INVERSE_SSE_H
28#define EIGEN_INVERSE_SSE_H
29
30namespace Eigen {
31
32namespace internal {
33
34template<typename MatrixType, typename ResultType>
35struct compute_inverse_size4<Architecture::SSE, float, MatrixType, ResultType>
36{
37  enum {
38    MatrixAlignment     = traits<MatrixType>::Alignment,
39    ResultAlignment     = traits<ResultType>::Alignment,
40    StorageOrdersMatch  = (MatrixType::Flags&RowMajorBit) == (ResultType::Flags&RowMajorBit)
41  };
42  typedef typename conditional<(MatrixType::Flags&LinearAccessBit),MatrixType const &,typename MatrixType::PlainObject>::type ActualMatrixType;
43
44  static void run(const MatrixType& mat, ResultType& result)
45  {
46    ActualMatrixType matrix(mat);
47    EIGEN_ALIGN16 const unsigned int _Sign_PNNP[4] = { 0x00000000, 0x80000000, 0x80000000, 0x00000000 };
48
49    // Load the full matrix into registers
50    __m128 _L1 = matrix.template packet<MatrixAlignment>( 0);
51    __m128 _L2 = matrix.template packet<MatrixAlignment>( 4);
52    __m128 _L3 = matrix.template packet<MatrixAlignment>( 8);
53    __m128 _L4 = matrix.template packet<MatrixAlignment>(12);
54
55    // The inverse is calculated using "Divide and Conquer" technique. The
56    // original matrix is divide into four 2x2 sub-matrices. Since each
57    // register holds four matrix element, the smaller matrices are
58    // represented as a registers. Hence we get a better locality of the
59    // calculations.
60
61    __m128 A, B, C, D; // the four sub-matrices
62    if(!StorageOrdersMatch)
63    {
64      A = _mm_unpacklo_ps(_L1, _L2);
65      B = _mm_unpacklo_ps(_L3, _L4);
66      C = _mm_unpackhi_ps(_L1, _L2);
67      D = _mm_unpackhi_ps(_L3, _L4);
68    }
69    else
70    {
71      A = _mm_movelh_ps(_L1, _L2);
72      B = _mm_movehl_ps(_L2, _L1);
73      C = _mm_movelh_ps(_L3, _L4);
74      D = _mm_movehl_ps(_L4, _L3);
75    }
76
77    __m128 iA, iB, iC, iD,                 // partial inverse of the sub-matrices
78            DC, AB;
79    __m128 dA, dB, dC, dD;                 // determinant of the sub-matrices
80    __m128 det, d, d1, d2;
81    __m128 rd;                             // reciprocal of the determinant
82
83    //  AB = A# * B
84    AB = _mm_mul_ps(_mm_shuffle_ps(A,A,0x0F), B);
85    AB = _mm_sub_ps(AB,_mm_mul_ps(_mm_shuffle_ps(A,A,0xA5), _mm_shuffle_ps(B,B,0x4E)));
86    //  DC = D# * C
87    DC = _mm_mul_ps(_mm_shuffle_ps(D,D,0x0F), C);
88    DC = _mm_sub_ps(DC,_mm_mul_ps(_mm_shuffle_ps(D,D,0xA5), _mm_shuffle_ps(C,C,0x4E)));
89
90    //  dA = |A|
91    dA = _mm_mul_ps(_mm_shuffle_ps(A, A, 0x5F),A);
92    dA = _mm_sub_ss(dA, _mm_movehl_ps(dA,dA));
93    //  dB = |B|
94    dB = _mm_mul_ps(_mm_shuffle_ps(B, B, 0x5F),B);
95    dB = _mm_sub_ss(dB, _mm_movehl_ps(dB,dB));
96
97    //  dC = |C|
98    dC = _mm_mul_ps(_mm_shuffle_ps(C, C, 0x5F),C);
99    dC = _mm_sub_ss(dC, _mm_movehl_ps(dC,dC));
100    //  dD = |D|
101    dD = _mm_mul_ps(_mm_shuffle_ps(D, D, 0x5F),D);
102    dD = _mm_sub_ss(dD, _mm_movehl_ps(dD,dD));
103
104    //  d = trace(AB*DC) = trace(A#*B*D#*C)
105    d = _mm_mul_ps(_mm_shuffle_ps(DC,DC,0xD8),AB);
106
107    //  iD = C*A#*B
108    iD = _mm_mul_ps(_mm_shuffle_ps(C,C,0xA0), _mm_movelh_ps(AB,AB));
109    iD = _mm_add_ps(iD,_mm_mul_ps(_mm_shuffle_ps(C,C,0xF5), _mm_movehl_ps(AB,AB)));
110    //  iA = B*D#*C
111    iA = _mm_mul_ps(_mm_shuffle_ps(B,B,0xA0), _mm_movelh_ps(DC,DC));
112    iA = _mm_add_ps(iA,_mm_mul_ps(_mm_shuffle_ps(B,B,0xF5), _mm_movehl_ps(DC,DC)));
113
114    //  d = trace(AB*DC) = trace(A#*B*D#*C) [continue]
115    d  = _mm_add_ps(d, _mm_movehl_ps(d, d));
116    d  = _mm_add_ss(d, _mm_shuffle_ps(d, d, 1));
117    d1 = _mm_mul_ss(dA,dD);
118    d2 = _mm_mul_ss(dB,dC);
119
120    //  iD = D*|A| - C*A#*B
121    iD = _mm_sub_ps(_mm_mul_ps(D,_mm_shuffle_ps(dA,dA,0)), iD);
122
123    //  iA = A*|D| - B*D#*C;
124    iA = _mm_sub_ps(_mm_mul_ps(A,_mm_shuffle_ps(dD,dD,0)), iA);
125
126    //  det = |A|*|D| + |B|*|C| - trace(A#*B*D#*C)
127    det = _mm_sub_ss(_mm_add_ss(d1,d2),d);
128    rd  = _mm_div_ss(_mm_set_ss(1.0f), det);
129
130//     #ifdef ZERO_SINGULAR
131//         rd = _mm_and_ps(_mm_cmpneq_ss(det,_mm_setzero_ps()), rd);
132//     #endif
133
134    //  iB = D * (A#B)# = D*B#*A
135    iB = _mm_mul_ps(D, _mm_shuffle_ps(AB,AB,0x33));
136    iB = _mm_sub_ps(iB, _mm_mul_ps(_mm_shuffle_ps(D,D,0xB1), _mm_shuffle_ps(AB,AB,0x66)));
137    //  iC = A * (D#C)# = A*C#*D
138    iC = _mm_mul_ps(A, _mm_shuffle_ps(DC,DC,0x33));
139    iC = _mm_sub_ps(iC, _mm_mul_ps(_mm_shuffle_ps(A,A,0xB1), _mm_shuffle_ps(DC,DC,0x66)));
140
141    rd = _mm_shuffle_ps(rd,rd,0);
142    rd = _mm_xor_ps(rd, _mm_load_ps((float*)_Sign_PNNP));
143
144    //  iB = C*|B| - D*B#*A
145    iB = _mm_sub_ps(_mm_mul_ps(C,_mm_shuffle_ps(dB,dB,0)), iB);
146
147    //  iC = B*|C| - A*C#*D;
148    iC = _mm_sub_ps(_mm_mul_ps(B,_mm_shuffle_ps(dC,dC,0)), iC);
149
150    //  iX = iX / det
151    iA = _mm_mul_ps(rd,iA);
152    iB = _mm_mul_ps(rd,iB);
153    iC = _mm_mul_ps(rd,iC);
154    iD = _mm_mul_ps(rd,iD);
155
156    Index res_stride = result.outerStride();
157    float* res = result.data();
158    pstoret<float, Packet4f, ResultAlignment>(res+0,            _mm_shuffle_ps(iA,iB,0x77));
159    pstoret<float, Packet4f, ResultAlignment>(res+res_stride,   _mm_shuffle_ps(iA,iB,0x22));
160    pstoret<float, Packet4f, ResultAlignment>(res+2*res_stride, _mm_shuffle_ps(iC,iD,0x77));
161    pstoret<float, Packet4f, ResultAlignment>(res+3*res_stride, _mm_shuffle_ps(iC,iD,0x22));
162  }
163
164};
165
166template<typename MatrixType, typename ResultType>
167struct compute_inverse_size4<Architecture::SSE, double, MatrixType, ResultType>
168{
169  enum {
170    MatrixAlignment     = traits<MatrixType>::Alignment,
171    ResultAlignment     = traits<ResultType>::Alignment,
172    StorageOrdersMatch  = (MatrixType::Flags&RowMajorBit) == (ResultType::Flags&RowMajorBit)
173  };
174  typedef typename conditional<(MatrixType::Flags&LinearAccessBit),MatrixType const &,typename MatrixType::PlainObject>::type ActualMatrixType;
175
176  static void run(const MatrixType& mat, ResultType& result)
177  {
178    ActualMatrixType matrix(mat);
179    const __m128d _Sign_NP = _mm_castsi128_pd(_mm_set_epi32(0x0,0x0,0x80000000,0x0));
180    const __m128d _Sign_PN = _mm_castsi128_pd(_mm_set_epi32(0x80000000,0x0,0x0,0x0));
181
182    // The inverse is calculated using "Divide and Conquer" technique. The
183    // original matrix is divide into four 2x2 sub-matrices. Since each
184    // register of the matrix holds two elements, the smaller matrices are
185    // consisted of two registers. Hence we get a better locality of the
186    // calculations.
187
188    // the four sub-matrices
189    __m128d A1, A2, B1, B2, C1, C2, D1, D2;
190
191    if(StorageOrdersMatch)
192    {
193      A1 = matrix.template packet<MatrixAlignment>( 0); B1 = matrix.template packet<MatrixAlignment>( 2);
194      A2 = matrix.template packet<MatrixAlignment>( 4); B2 = matrix.template packet<MatrixAlignment>( 6);
195      C1 = matrix.template packet<MatrixAlignment>( 8); D1 = matrix.template packet<MatrixAlignment>(10);
196      C2 = matrix.template packet<MatrixAlignment>(12); D2 = matrix.template packet<MatrixAlignment>(14);
197    }
198    else
199    {
200      __m128d tmp;
201      A1 = matrix.template packet<MatrixAlignment>( 0); C1 = matrix.template packet<MatrixAlignment>( 2);
202      A2 = matrix.template packet<MatrixAlignment>( 4); C2 = matrix.template packet<MatrixAlignment>( 6);
203      tmp = A1;
204      A1 = _mm_unpacklo_pd(A1,A2);
205      A2 = _mm_unpackhi_pd(tmp,A2);
206      tmp = C1;
207      C1 = _mm_unpacklo_pd(C1,C2);
208      C2 = _mm_unpackhi_pd(tmp,C2);
209
210      B1 = matrix.template packet<MatrixAlignment>( 8); D1 = matrix.template packet<MatrixAlignment>(10);
211      B2 = matrix.template packet<MatrixAlignment>(12); D2 = matrix.template packet<MatrixAlignment>(14);
212      tmp = B1;
213      B1 = _mm_unpacklo_pd(B1,B2);
214      B2 = _mm_unpackhi_pd(tmp,B2);
215      tmp = D1;
216      D1 = _mm_unpacklo_pd(D1,D2);
217      D2 = _mm_unpackhi_pd(tmp,D2);
218    }
219
220    __m128d iA1, iA2, iB1, iB2, iC1, iC2, iD1, iD2,     // partial invese of the sub-matrices
221            DC1, DC2, AB1, AB2;
222    __m128d dA, dB, dC, dD;     // determinant of the sub-matrices
223    __m128d det, d1, d2, rd;
224
225    //  dA = |A|
226    dA = _mm_shuffle_pd(A2, A2, 1);
227    dA = _mm_mul_pd(A1, dA);
228    dA = _mm_sub_sd(dA, _mm_shuffle_pd(dA,dA,3));
229    //  dB = |B|
230    dB = _mm_shuffle_pd(B2, B2, 1);
231    dB = _mm_mul_pd(B1, dB);
232    dB = _mm_sub_sd(dB, _mm_shuffle_pd(dB,dB,3));
233
234    //  AB = A# * B
235    AB1 = _mm_mul_pd(B1, _mm_shuffle_pd(A2,A2,3));
236    AB2 = _mm_mul_pd(B2, _mm_shuffle_pd(A1,A1,0));
237    AB1 = _mm_sub_pd(AB1, _mm_mul_pd(B2, _mm_shuffle_pd(A1,A1,3)));
238    AB2 = _mm_sub_pd(AB2, _mm_mul_pd(B1, _mm_shuffle_pd(A2,A2,0)));
239
240    //  dC = |C|
241    dC = _mm_shuffle_pd(C2, C2, 1);
242    dC = _mm_mul_pd(C1, dC);
243    dC = _mm_sub_sd(dC, _mm_shuffle_pd(dC,dC,3));
244    //  dD = |D|
245    dD = _mm_shuffle_pd(D2, D2, 1);
246    dD = _mm_mul_pd(D1, dD);
247    dD = _mm_sub_sd(dD, _mm_shuffle_pd(dD,dD,3));
248
249    //  DC = D# * C
250    DC1 = _mm_mul_pd(C1, _mm_shuffle_pd(D2,D2,3));
251    DC2 = _mm_mul_pd(C2, _mm_shuffle_pd(D1,D1,0));
252    DC1 = _mm_sub_pd(DC1, _mm_mul_pd(C2, _mm_shuffle_pd(D1,D1,3)));
253    DC2 = _mm_sub_pd(DC2, _mm_mul_pd(C1, _mm_shuffle_pd(D2,D2,0)));
254
255    //  rd = trace(AB*DC) = trace(A#*B*D#*C)
256    d1 = _mm_mul_pd(AB1, _mm_shuffle_pd(DC1, DC2, 0));
257    d2 = _mm_mul_pd(AB2, _mm_shuffle_pd(DC1, DC2, 3));
258    rd = _mm_add_pd(d1, d2);
259    rd = _mm_add_sd(rd, _mm_shuffle_pd(rd, rd,3));
260
261    //  iD = C*A#*B
262    iD1 = _mm_mul_pd(AB1, _mm_shuffle_pd(C1,C1,0));
263    iD2 = _mm_mul_pd(AB1, _mm_shuffle_pd(C2,C2,0));
264    iD1 = _mm_add_pd(iD1, _mm_mul_pd(AB2, _mm_shuffle_pd(C1,C1,3)));
265    iD2 = _mm_add_pd(iD2, _mm_mul_pd(AB2, _mm_shuffle_pd(C2,C2,3)));
266
267    //  iA = B*D#*C
268    iA1 = _mm_mul_pd(DC1, _mm_shuffle_pd(B1,B1,0));
269    iA2 = _mm_mul_pd(DC1, _mm_shuffle_pd(B2,B2,0));
270    iA1 = _mm_add_pd(iA1, _mm_mul_pd(DC2, _mm_shuffle_pd(B1,B1,3)));
271    iA2 = _mm_add_pd(iA2, _mm_mul_pd(DC2, _mm_shuffle_pd(B2,B2,3)));
272
273    //  iD = D*|A| - C*A#*B
274    dA = _mm_shuffle_pd(dA,dA,0);
275    iD1 = _mm_sub_pd(_mm_mul_pd(D1, dA), iD1);
276    iD2 = _mm_sub_pd(_mm_mul_pd(D2, dA), iD2);
277
278    //  iA = A*|D| - B*D#*C;
279    dD = _mm_shuffle_pd(dD,dD,0);
280    iA1 = _mm_sub_pd(_mm_mul_pd(A1, dD), iA1);
281    iA2 = _mm_sub_pd(_mm_mul_pd(A2, dD), iA2);
282
283    d1 = _mm_mul_sd(dA, dD);
284    d2 = _mm_mul_sd(dB, dC);
285
286    //  iB = D * (A#B)# = D*B#*A
287    iB1 = _mm_mul_pd(D1, _mm_shuffle_pd(AB2,AB1,1));
288    iB2 = _mm_mul_pd(D2, _mm_shuffle_pd(AB2,AB1,1));
289    iB1 = _mm_sub_pd(iB1, _mm_mul_pd(_mm_shuffle_pd(D1,D1,1), _mm_shuffle_pd(AB2,AB1,2)));
290    iB2 = _mm_sub_pd(iB2, _mm_mul_pd(_mm_shuffle_pd(D2,D2,1), _mm_shuffle_pd(AB2,AB1,2)));
291
292    //  det = |A|*|D| + |B|*|C| - trace(A#*B*D#*C)
293    det = _mm_add_sd(d1, d2);
294    det = _mm_sub_sd(det, rd);
295
296    //  iC = A * (D#C)# = A*C#*D
297    iC1 = _mm_mul_pd(A1, _mm_shuffle_pd(DC2,DC1,1));
298    iC2 = _mm_mul_pd(A2, _mm_shuffle_pd(DC2,DC1,1));
299    iC1 = _mm_sub_pd(iC1, _mm_mul_pd(_mm_shuffle_pd(A1,A1,1), _mm_shuffle_pd(DC2,DC1,2)));
300    iC2 = _mm_sub_pd(iC2, _mm_mul_pd(_mm_shuffle_pd(A2,A2,1), _mm_shuffle_pd(DC2,DC1,2)));
301
302    rd = _mm_div_sd(_mm_set_sd(1.0), det);
303//     #ifdef ZERO_SINGULAR
304//         rd = _mm_and_pd(_mm_cmpneq_sd(det,_mm_setzero_pd()), rd);
305//     #endif
306    rd = _mm_shuffle_pd(rd,rd,0);
307
308    //  iB = C*|B| - D*B#*A
309    dB = _mm_shuffle_pd(dB,dB,0);
310    iB1 = _mm_sub_pd(_mm_mul_pd(C1, dB), iB1);
311    iB2 = _mm_sub_pd(_mm_mul_pd(C2, dB), iB2);
312
313    d1 = _mm_xor_pd(rd, _Sign_PN);
314    d2 = _mm_xor_pd(rd, _Sign_NP);
315
316    //  iC = B*|C| - A*C#*D;
317    dC = _mm_shuffle_pd(dC,dC,0);
318    iC1 = _mm_sub_pd(_mm_mul_pd(B1, dC), iC1);
319    iC2 = _mm_sub_pd(_mm_mul_pd(B2, dC), iC2);
320
321    Index res_stride = result.outerStride();
322    double* res = result.data();
323    pstoret<double, Packet2d, ResultAlignment>(res+0,             _mm_mul_pd(_mm_shuffle_pd(iA2, iA1, 3), d1));
324    pstoret<double, Packet2d, ResultAlignment>(res+res_stride,    _mm_mul_pd(_mm_shuffle_pd(iA2, iA1, 0), d2));
325    pstoret<double, Packet2d, ResultAlignment>(res+2,             _mm_mul_pd(_mm_shuffle_pd(iB2, iB1, 3), d1));
326    pstoret<double, Packet2d, ResultAlignment>(res+res_stride+2,  _mm_mul_pd(_mm_shuffle_pd(iB2, iB1, 0), d2));
327    pstoret<double, Packet2d, ResultAlignment>(res+2*res_stride,  _mm_mul_pd(_mm_shuffle_pd(iC2, iC1, 3), d1));
328    pstoret<double, Packet2d, ResultAlignment>(res+3*res_stride,  _mm_mul_pd(_mm_shuffle_pd(iC2, iC1, 0), d2));
329    pstoret<double, Packet2d, ResultAlignment>(res+2*res_stride+2,_mm_mul_pd(_mm_shuffle_pd(iD2, iD1, 3), d1));
330    pstoret<double, Packet2d, ResultAlignment>(res+3*res_stride+2,_mm_mul_pd(_mm_shuffle_pd(iD2, iD1, 0), d2));
331  }
332};
333
334} // end namespace internal
335
336} // end namespace Eigen
337
338#endif // EIGEN_INVERSE_SSE_H
339