1/*
2 * Copyright (C) 2008 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17/* ---- includes ----------------------------------------------------------- */
18
19#include "b_TensorEm/Int32Mat.h"
20#include "b_TensorEm/Functions.h"
21#include "b_BasicEm/Math.h"
22#include "b_BasicEm/Functions.h"
23#include "b_BasicEm/Memory.h"
24
25/* ------------------------------------------------------------------------- */
26
27/* ========================================================================= */
28/*                                                                           */
29/* ---- \ghd{ auxiliary functions } ---------------------------------------- */
30/*                                                                           */
31/* ========================================================================= */
32
33/* ------------------------------------------------------------------------- */
34
35void bts_Int32Mat_reduceToNBits( int32* ptrA, uint32 sizeA, int32* bbpPtrA, uint32 nBitsA )
36{
37	int32 shiftL;
38
39	/* find max element */
40	int32 maxL = 0;
41	int32* ptrL = ptrA;
42	int32 iL = sizeA;
43	while( iL-- )
44	{
45		int32 xL = *ptrL++;
46		if( xL < 0 ) xL = -xL;
47		if( xL > maxL ) maxL = xL;
48	}
49
50	/* determine shift */
51	shiftL = bts_absIntLog2( maxL ) + 1 - nBitsA;
52
53	if( shiftL > 0 )
54	{
55		ptrL = ptrA;
56		iL = sizeA;
57		while( iL-- )
58		{
59			*ptrL = ( ( *ptrL >> ( shiftL - 1 ) ) + 1 ) >> 1;
60			ptrL++;
61		}
62
63		*bbpPtrA -= shiftL;
64	}
65}
66
67/* ------------------------------------------------------------------------- */
68
69/* ========================================================================= */
70/*                                                                           */
71/* ---- \ghd{ constructor / destructor } ----------------------------------- */
72/*                                                                           */
73/* ========================================================================= */
74
75/* ------------------------------------------------------------------------- */
76
77void bts_Int32Mat_init( struct bbs_Context* cpA,
78					    struct bts_Int32Mat* ptrA )
79{
80	ptrA->widthE = 0;
81	bbs_Int32Arr_init( cpA, &ptrA->arrE );
82}
83
84/* ------------------------------------------------------------------------- */
85
86void bts_Int32Mat_exit( struct bbs_Context* cpA,
87					    struct bts_Int32Mat* ptrA )
88{
89	ptrA->widthE = 0;
90	bbs_Int32Arr_exit( cpA, &ptrA->arrE );
91}
92/* ------------------------------------------------------------------------- */
93
94/* ========================================================================= */
95/*                                                                           */
96/* ---- \ghd{ operators } -------------------------------------------------- */
97/*                                                                           */
98/* ========================================================================= */
99
100/* ------------------------------------------------------------------------- */
101
102/* ========================================================================= */
103/*                                                                           */
104/* ---- \ghd{ query functions } -------------------------------------------- */
105/*                                                                           */
106/* ========================================================================= */
107
108/* ------------------------------------------------------------------------- */
109
110/* ========================================================================= */
111/*                                                                           */
112/* ---- \ghd{ modify functions } ------------------------------------------- */
113/*                                                                           */
114/* ========================================================================= */
115
116/* ------------------------------------------------------------------------- */
117
118void bts_Int32Mat_create( struct bbs_Context* cpA,
119						  struct bts_Int32Mat* ptrA,
120						  int32 widthA,
121				          struct bbs_MemSeg* mspA )
122{
123	if( bbs_Context_error( cpA ) ) return;
124	bbs_Int32Arr_create( cpA, &ptrA->arrE, widthA * widthA, mspA );
125	ptrA->widthE = widthA;
126}
127
128/* ------------------------------------------------------------------------- */
129
130void bts_Int32Mat_copy( struct bbs_Context* cpA,
131					    struct bts_Int32Mat* ptrA,
132						const struct bts_Int32Mat* srcPtrA )
133{
134	if( ptrA->widthE != srcPtrA->widthE )
135	{
136		bbs_ERROR0( "void bts_Int32Mat_copy( struct bts_Int32Mat* ptrA, struct bts_Int32Mat* srcPtrA ):\n"
137			       "size mismatch" );
138		return;
139	}
140
141	bbs_Int32Arr_copy( cpA, &ptrA->arrE, &srcPtrA->arrE );
142}
143
144/* ------------------------------------------------------------------------- */
145
146/* ========================================================================= */
147/*                                                                           */
148/* ---- \ghd{ I/O } -------------------------------------------------------- */
149/*                                                                           */
150/* ========================================================================= */
151
152/* ------------------------------------------------------------------------- */
153
154uint32 bts_Int32Mat_memSize( struct bbs_Context* cpA,
155							 const struct bts_Int32Mat *ptrA )
156{
157	return  bbs_SIZEOF16( uint32 )
158		  + bbs_SIZEOF16( uint32 ) /* version */
159		  + bbs_SIZEOF16( ptrA->widthE )
160		  + bbs_Int32Arr_memSize( cpA, &ptrA->arrE );
161}
162
163/* ------------------------------------------------------------------------- */
164
165uint32 bts_Int32Mat_memWrite( struct bbs_Context* cpA,
166							  const struct bts_Int32Mat* ptrA,
167							  uint16* memPtrA )
168{
169	uint32 memSizeL = bts_Int32Mat_memSize( cpA, ptrA );
170	memPtrA += bbs_memWrite32( &memSizeL, memPtrA );
171	memPtrA += bbs_memWriteUInt32( bts_INT32MAT_VERSION, memPtrA );
172	memPtrA += bbs_memWrite32( &ptrA->widthE, memPtrA );
173	memPtrA += bbs_Int32Arr_memWrite( cpA, &ptrA->arrE, memPtrA );
174	return memSizeL;
175}
176
177/* ------------------------------------------------------------------------- */
178
179uint32 bts_Int32Mat_memRead( struct bbs_Context* cpA,
180							 struct bts_Int32Mat* ptrA,
181							 const uint16* memPtrA,
182				             struct bbs_MemSeg* mspA )
183{
184	uint32 memSizeL, versionL;
185	if( bbs_Context_error( cpA ) ) return 0;
186	memPtrA += bbs_memRead32( &memSizeL, memPtrA );
187	memPtrA += bbs_memReadVersion32( cpA, &versionL, bts_INT32MAT_VERSION, memPtrA );
188	memPtrA += bbs_memRead32( &ptrA->widthE, memPtrA );
189	memPtrA += bbs_Int32Arr_memRead( cpA, &ptrA->arrE, memPtrA, mspA );
190
191	if( memSizeL != bts_Int32Mat_memSize( cpA, ptrA ) )
192	{
193		bbs_ERR0( bbs_ERR_CORRUPT_DATA, "uint32 bts_Int32Mat_memRead( const struct bts_Int32Mat* ptrA, const void* memPtrA ):\n"
194                  "size mismatch" );
195	}
196	return memSizeL;
197}
198
199/* ------------------------------------------------------------------------- */
200
201/* ========================================================================= */
202/*                                                                           */
203/* ---- \ghd{ exec functions } --------------------------------------------- */
204/*                                                                           */
205/* ========================================================================= */
206
207/* ------------------------------------------------------------------------- */
208
209flag bts_Int32Mat_solve( struct bbs_Context* cpA,
210						 const int32* matA,
211						 int32 matWidthA,
212						 const int32* inVecA,
213						 int32* outVecA,
214						 int32 bbpA,
215						 int32* tmpMatA,
216						 int32* tmpVecA )
217{
218	bbs_memcpy32( tmpMatA, matA, ( matWidthA * matWidthA ) * bbs_SIZEOF32( int32 ) );
219
220	return bts_Int32Mat_solve2( cpA,
221		                        tmpMatA,
222								matWidthA,
223								inVecA,
224								outVecA,
225								bbpA,
226								tmpVecA );
227}
228
229/* ------------------------------------------------------------------------- */
230
231flag bts_Int32Mat_solve2( struct bbs_Context* cpA,
232						  int32* matA,
233						  int32 matWidthA,
234						  const int32* inVecA,
235						  int32* outVecA,
236						  int32 bbpA,
237						  int32* tmpVecA )
238{
239	int32 sizeL = matWidthA;
240	int32 bbpL = bbpA;
241	int32 iL, jL, kL;
242	int32 iPivL;
243	int32 jPivL;
244
245	int32* vecL      = outVecA;
246	int32* matL      = matA;
247	int32* checkArrL = tmpVecA;
248
249	for( iL = 0; iL < sizeL; iL++ )
250	{
251		checkArrL[ iL ] = 0;
252	}
253
254	bbs_memcpy32( outVecA, inVecA, sizeL * bbs_SIZEOF32( int32 ) );
255
256	iPivL = 0;
257
258	for( kL = 0; kL < sizeL; kL++ )
259	{
260		/* find pivot */
261		int32 maxAbsL = 0;
262		int32* pivRowL;
263
264		int32 bbp_pivRowL, bbp_vecL, shiftL;
265
266		jPivL = -1;
267		for( iL = 0; iL < sizeL; iL++ )
268		{
269			if( checkArrL[ iL ] != 1 )
270			{
271				int32* rowL = matL + ( iL * sizeL );
272				for( jL = 0; jL < sizeL; jL++ )
273				{
274					if( checkArrL[ jL ] == 0 )
275					{
276						int32 absElemL = rowL[ jL ];
277						if( absElemL < 0 ) absElemL = -absElemL;
278						if( maxAbsL < absElemL )
279						{
280							maxAbsL = absElemL;
281							iPivL = iL;
282							jPivL = jL;
283						}
284					}
285					else if( checkArrL[ jL ] > 1 )
286					{
287						return FALSE;
288					}
289				}
290			}
291		}
292
293		/* successfull ? */
294		if( jPivL < 0 )
295		{
296			return FALSE;
297		}
298
299		checkArrL[ jPivL ]++;
300
301		/* exchange rows to put pivot on diagonal, if neccessary */
302		if( iPivL != jPivL )
303		{
304			int32* row1PtrL = matL + ( iPivL * sizeL );
305			int32* row2PtrL = matL + ( jPivL * sizeL );
306			for( jL = 0; jL < sizeL; jL++ )
307			{
308				int32 tmpL = *row1PtrL;
309				*row1PtrL++ = *row2PtrL;
310				*row2PtrL++ = tmpL;
311			}
312
313			{
314				int32 tmpL = vecL[ jPivL ];
315				vecL[ jPivL ] = vecL[ iPivL ];
316				vecL[ iPivL ] = tmpL;
317			}
318		}
319		/* now index jPivL specifies pivot row and maximum element */
320
321
322		/**	Overflow protection: only if the highest bit of the largest matrix element is set,
323		 *	we need to shift the whole matrix and the right side vector 1 bit to the right,
324		 *	to make sure there can be no overflow when the pivot row gets subtracted from the
325		 *	other rows.
326		 *	Getting that close to overflow is a rare event, so this shift will happen only
327		 *	occasionally, or not at all.
328		 */
329		if( maxAbsL & 1073741824 )  /*( 1 << 30 )*/
330		{
331			/* right shift matrix by 1 */
332			int32 iL = sizeL * sizeL;
333			int32* ptrL = matL;
334			while( iL-- )
335			{
336				*ptrL = ( *ptrL + 1 ) >> 1;
337				ptrL++;
338			}
339
340			/* right shift right side vector by 1 */
341			iL = sizeL;
342			ptrL = vecL;
343			while( iL-- )
344			{
345				*ptrL = ( *ptrL + 1 ) >> 1;
346				ptrL++;
347			}
348
349			/* decrement bbpL */
350			bbpL--;
351		}
352
353
354		/* reduce elements of pivot row to 15 bit */
355		pivRowL = matL + jPivL * sizeL;
356		bbp_pivRowL = bbpL;
357		bts_Int32Mat_reduceToNBits( pivRowL, sizeL, &bbp_pivRowL, 15 );
358
359		/* scale pivot row such that maximum equals 1 */
360		{
361			int32 maxL = pivRowL[ jPivL ];
362			int32 bbp_maxL = bbp_pivRowL;
363			int32 factorL = 1073741824 / maxL; /*( 1 << 30 )*/
364
365			for( jL = 0; jL < sizeL; jL++ )
366			{
367				pivRowL[ jL ] = ( pivRowL[ jL ] * factorL + ( 1 << 14 ) ) >> 15;
368			}
369			bbp_pivRowL = 15;
370
371			/* set to 1 to avoid computational errors */
372			pivRowL[ jPivL ] = ( int32 )1 << bbp_pivRowL;
373
374			shiftL = 30 - bts_absIntLog2( vecL[ jPivL ] );
375
376			vecL[ jPivL ] = ( vecL[ jPivL ] << shiftL ) / maxL;
377			bbp_vecL = bbpL + shiftL - bbp_maxL;
378
379			bbs_int32ReduceToNBits( &( vecL[ jPivL ] ), &bbp_vecL, 15 );
380		}
381
382		/* subtract pivot row from all other rows */
383		for( iL = 0; iL < sizeL; iL++ )
384		{
385			if( iL != jPivL )
386			{
387				int32* rowPtrL = matL + iL * sizeL;
388
389				int32 tmpL = *( rowPtrL + jPivL );
390				int32 bbp_tmpL = bbpL;
391				bbs_int32ReduceToNBits( &tmpL, &bbp_tmpL, 15 );
392
393				shiftL = bbp_tmpL + bbp_pivRowL - bbpL;
394				if( shiftL > 0 )
395				{
396					for( jL = 0; jL < sizeL; jL++ )
397					{
398						*rowPtrL++ -= ( ( ( tmpL * pivRowL[ jL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
399					}
400				}
401				else
402				{
403					for( jL = 0; jL < sizeL; jL++ )
404					{
405						*rowPtrL++ -= ( tmpL * pivRowL[ jL ] ) << -shiftL;
406					}
407				}
408
409				shiftL = bbp_tmpL + bbp_vecL - bbpL;
410				if( shiftL > 0 )
411				{
412					vecL[ iL ] -= ( ( ( tmpL * vecL[ jPivL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
413				}
414				else
415				{
416					vecL[ iL ] -= ( tmpL * vecL[ jPivL ] ) << -shiftL;
417				}
418			}
419		}
420
421		/* change bbp of pivot row back to bbpL */
422		shiftL = bbpL - bbp_pivRowL;
423		if( shiftL >= 0 )
424		{
425			for( jL = 0; jL < sizeL; jL++ )
426			{
427				pivRowL[ jL ] <<= shiftL;
428			}
429		}
430		else
431		{
432			shiftL = -shiftL;
433			for( jL = 0; jL < sizeL; jL++ )
434			{
435				pivRowL[ jL ] = ( ( pivRowL[ jL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
436			}
437		}
438
439		shiftL = bbpL - bbp_vecL;
440		if( shiftL >= 0 )
441		{
442			vecL[ jPivL ] <<= shiftL;
443		}
444		else
445		{
446			shiftL = -shiftL;
447			vecL[ jPivL ] = ( ( vecL[ jPivL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
448		}
449/*
450if( sizeL <= 5 ) bts_Int32Mat_print( matL, vecL, sizeL, bbpL );
451*/
452	}	/* of kL */
453
454	/* in case bbpL has been decreased by the overflow protection, change it back now */
455	if( bbpA > bbpL )
456	{
457		/* find largest element of solution vector */
458		int32 maxL = 0;
459		int32 iL, shiftL;
460		for( iL = 0; iL < sizeL; iL++ )
461		{
462			int32 xL = vecL[ iL ];
463			if( xL < 0 ) xL = -xL;
464			if( xL > maxL ) maxL = xL;
465		}
466
467		/* check whether we can left shift without overflow */
468		shiftL = 30 - bts_absIntLog2( maxL );
469		if( shiftL < ( bbpA - bbpL ) )
470		{
471			/*
472			    bbs_WARNING1( "flag bts_Int32Mat_solve2( ... ): getting overflow when trying to "
473				"compute solution vector with bbp = %d. Choose smaller bbp.\n", bbpA );
474			*/
475
476			return FALSE;
477		}
478
479		/* shift left */
480		shiftL = bbpA - bbpL;
481		for( iL = 0; iL < sizeL; iL++ ) vecL[ iL ] <<= shiftL;
482	}
483
484	return TRUE;
485}
486
487/* ------------------------------------------------------------------------- */
488
489/* ========================================================================= */
490
491