1/*
2 * Copyright (c) 2008-2016 Stefan Krah. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright
12 *    notice, this list of conditions and the following disclaimer in the
13 *    documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25 * SUCH DAMAGE.
26 */
27
28
29#ifndef UMODARITH_H
30#define UMODARITH_H
31
32
33#include "constants.h"
34#include "mpdecimal.h"
35#include "typearith.h"
36
37
38/* Bignum: Low level routines for unsigned modular arithmetic. These are
39   used in the fast convolution functions for very large coefficients. */
40
41
42/**************************************************************************/
43/*                        ANSI modular arithmetic                         */
44/**************************************************************************/
45
46
47/*
48 * Restrictions: a < m and b < m
49 * ACL2 proof: umodarith.lisp: addmod-correct
50 */
51static inline mpd_uint_t
52addmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
53{
54    mpd_uint_t s;
55
56    s = a + b;
57    s = (s < a) ? s - m : s;
58    s = (s >= m) ? s - m : s;
59
60    return s;
61}
62
63/*
64 * Restrictions: a < m and b < m
65 * ACL2 proof: umodarith.lisp: submod-2-correct
66 */
67static inline mpd_uint_t
68submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
69{
70    mpd_uint_t d;
71
72    d = a - b;
73    d = (a < b) ? d + m : d;
74
75    return d;
76}
77
78/*
79 * Restrictions: a < 2m and b < 2m
80 * ACL2 proof: umodarith.lisp: section ext-submod
81 */
82static inline mpd_uint_t
83ext_submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
84{
85    mpd_uint_t d;
86
87    a = (a >= m) ? a - m : a;
88    b = (b >= m) ? b - m : b;
89
90    d = a - b;
91    d = (a < b) ? d + m : d;
92
93    return d;
94}
95
96/*
97 * Reduce double word modulo m.
98 * Restrictions: m != 0
99 * ACL2 proof: umodarith.lisp: section dw-reduce
100 */
101static inline mpd_uint_t
102dw_reduce(mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
103{
104    mpd_uint_t r1, r2, w;
105
106    _mpd_div_word(&w, &r1, hi, m);
107    _mpd_div_words(&w, &r2, r1, lo, m);
108
109    return r2;
110}
111
112/*
113 * Subtract double word from a.
114 * Restrictions: a < m
115 * ACL2 proof: umodarith.lisp: section dw-submod
116 */
117static inline mpd_uint_t
118dw_submod(mpd_uint_t a, mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
119{
120    mpd_uint_t d, r;
121
122    r = dw_reduce(hi, lo, m);
123    d = a - r;
124    d = (a < r) ? d + m : d;
125
126    return d;
127}
128
129#ifdef CONFIG_64
130
131/**************************************************************************/
132/*                        64-bit modular arithmetic                       */
133/**************************************************************************/
134
135/*
136 * A proof of the algorithm is in literature/mulmod-64.txt. An ACL2
137 * proof is in umodarith.lisp: section "Fast modular reduction".
138 *
139 * Algorithm: calculate (a * b) % p:
140 *
141 *   a) hi, lo <- a * b       # Calculate a * b.
142 *
143 *   b) hi, lo <-  R(hi, lo)  # Reduce modulo p.
144 *
145 *   c) Repeat step b) until 0 <= hi * 2**64 + lo < 2*p.
146 *
147 *   d) If the result is less than p, return lo. Otherwise return lo - p.
148 */
149
150static inline mpd_uint_t
151x64_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
152{
153    mpd_uint_t hi, lo, x, y;
154
155
156    _mpd_mul_words(&hi, &lo, a, b);
157
158    if (m & (1ULL<<32)) { /* P1 */
159
160        /* first reduction */
161        x = y = hi;
162        hi >>= 32;
163
164        x = lo - x;
165        if (x > lo) hi--;
166
167        y <<= 32;
168        lo = y + x;
169        if (lo < y) hi++;
170
171        /* second reduction */
172        x = y = hi;
173        hi >>= 32;
174
175        x = lo - x;
176        if (x > lo) hi--;
177
178        y <<= 32;
179        lo = y + x;
180        if (lo < y) hi++;
181
182        return (hi || lo >= m ? lo - m : lo);
183    }
184    else if (m & (1ULL<<34)) { /* P2 */
185
186        /* first reduction */
187        x = y = hi;
188        hi >>= 30;
189
190        x = lo - x;
191        if (x > lo) hi--;
192
193        y <<= 34;
194        lo = y + x;
195        if (lo < y) hi++;
196
197        /* second reduction */
198        x = y = hi;
199        hi >>= 30;
200
201        x = lo - x;
202        if (x > lo) hi--;
203
204        y <<= 34;
205        lo = y + x;
206        if (lo < y) hi++;
207
208        /* third reduction */
209        x = y = hi;
210        hi >>= 30;
211
212        x = lo - x;
213        if (x > lo) hi--;
214
215        y <<= 34;
216        lo = y + x;
217        if (lo < y) hi++;
218
219        return (hi || lo >= m ? lo - m : lo);
220    }
221    else { /* P3 */
222
223        /* first reduction */
224        x = y = hi;
225        hi >>= 24;
226
227        x = lo - x;
228        if (x > lo) hi--;
229
230        y <<= 40;
231        lo = y + x;
232        if (lo < y) hi++;
233
234        /* second reduction */
235        x = y = hi;
236        hi >>= 24;
237
238        x = lo - x;
239        if (x > lo) hi--;
240
241        y <<= 40;
242        lo = y + x;
243        if (lo < y) hi++;
244
245        /* third reduction */
246        x = y = hi;
247        hi >>= 24;
248
249        x = lo - x;
250        if (x > lo) hi--;
251
252        y <<= 40;
253        lo = y + x;
254        if (lo < y) hi++;
255
256        return (hi || lo >= m ? lo - m : lo);
257    }
258}
259
260static inline void
261x64_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
262{
263    *a = x64_mulmod(*a, w, m);
264    *b = x64_mulmod(*b, w, m);
265}
266
267static inline void
268x64_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
269            mpd_uint_t m)
270{
271    *a0 = x64_mulmod(*a0, b0, m);
272    *a1 = x64_mulmod(*a1, b1, m);
273}
274
275static inline mpd_uint_t
276x64_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
277{
278    mpd_uint_t r = 1;
279
280    while (exp > 0) {
281        if (exp & 1)
282            r = x64_mulmod(r, base, umod);
283        base = x64_mulmod(base, base, umod);
284        exp >>= 1;
285    }
286
287    return r;
288}
289
290/* END CONFIG_64 */
291#else /* CONFIG_32 */
292
293
294/**************************************************************************/
295/*                        32-bit modular arithmetic                       */
296/**************************************************************************/
297
298#if defined(ANSI)
299#if !defined(LEGACY_COMPILER)
300/* HAVE_UINT64_T */
301static inline mpd_uint_t
302std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
303{
304    return ((mpd_uuint_t) a * b) % m;
305}
306
307static inline void
308std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
309{
310    *a = ((mpd_uuint_t) *a * w) % m;
311    *b = ((mpd_uuint_t) *b * w) % m;
312}
313
314static inline void
315std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
316            mpd_uint_t m)
317{
318    *a0 = ((mpd_uuint_t) *a0 * b0) % m;
319    *a1 = ((mpd_uuint_t) *a1 * b1) % m;
320}
321/* END HAVE_UINT64_T */
322#else
323/* LEGACY_COMPILER */
324static inline mpd_uint_t
325std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
326{
327    mpd_uint_t hi, lo, q, r;
328    _mpd_mul_words(&hi, &lo, a, b);
329    _mpd_div_words(&q, &r, hi, lo, m);
330    return r;
331}
332
333static inline void
334std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
335{
336    *a = std_mulmod(*a, w, m);
337    *b = std_mulmod(*b, w, m);
338}
339
340static inline void
341std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
342            mpd_uint_t m)
343{
344    *a0 = std_mulmod(*a0, b0, m);
345    *a1 = std_mulmod(*a1, b1, m);
346}
347/* END LEGACY_COMPILER */
348#endif
349
350static inline mpd_uint_t
351std_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
352{
353    mpd_uint_t r = 1;
354
355    while (exp > 0) {
356        if (exp & 1)
357            r = std_mulmod(r, base, umod);
358        base = std_mulmod(base, base, umod);
359        exp >>= 1;
360    }
361
362    return r;
363}
364#endif /* ANSI CONFIG_32 */
365
366
367/**************************************************************************/
368/*                    Pentium Pro modular arithmetic                      */
369/**************************************************************************/
370
371/*
372 * A proof of the algorithm is in literature/mulmod-ppro.txt. The FPU
373 * control word must be set to 64-bit precision and truncation mode
374 * prior to using these functions.
375 *
376 * Algorithm: calculate (a * b) % p:
377 *
378 *   p    := prime < 2**31
379 *   pinv := (long double)1.0 / p (precalculated)
380 *
381 *   a) n = a * b              # Calculate exact product.
382 *   b) qest = n * pinv        # Calculate estimate for q = n / p.
383 *   c) q = (qest+2**63)-2**63 # Truncate qest to the exact quotient.
384 *   d) r = n - q * p          # Calculate remainder.
385 *
386 * Remarks:
387 *
388 *   - p = dmod and pinv = dinvmod.
389 *   - dinvmod points to an array of three uint32_t, which is interpreted
390 *     as an 80 bit long double by fldt.
391 *   - Intel compilers prior to version 11 do not seem to handle the
392 *     __GNUC__ inline assembly correctly.
393 *   - random tests are provided in tests/extended/ppro_mulmod.c
394 */
395
396#if defined(PPRO)
397#if defined(ASM)
398
/*
 * Return (a * b) % dmod, computed on the x87 FPU (GCC inline assembly).
 *
 * Operands: %0 = retval (output), %1 = a, %2 = b,
 *           %3 = dmod (points to p as a double),
 *           %4 = dinvmod (points to pinv as an 80-bit long double),
 *           %5 = MPD_TWO63 (the 2**63 rounding constant, loaded with flds).
 * The FPU control word must be in 64-bit precision / truncation mode.
 * The x87 stack is empty again when the asm finishes.
 */
static inline mpd_uint_t
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm__ (
            "fildl  %2\n\t"                 /* push b */
            "fildl  %1\n\t"                 /* push a */
            "fmulp  %%st, %%st(1)\n\t"      /* n = a * b; stack: n */
            "fldt   (%4)\n\t"               /* push pinv */
            "fmul   %%st(1), %%st\n\t"      /* qest = pinv * n; stack: qest, n */
            "flds   %5\n\t"                 /* push 2**63 */
            "fadd   %%st, %%st(1)\n\t"      /* st(1) = qest + 2**63 */
            "fsubrp %%st, %%st(1)\n\t"      /* q = (qest + 2**63) - 2**63, pop; stack: q, n */
            "fldl   (%3)\n\t"               /* push p */
            "fmulp  %%st, %%st(1)\n\t"      /* stack: q * p, n */
            "fsubrp %%st, %%st(1)\n\t"      /* r = n - q * p, pop; stack: r */
            "fistpl %0\n\t"                 /* retval = r, pop; stack empty */
            : "=m" (retval)
            : "m" (a), "m" (b), "r" (dmod), "r" (dinvmod), "m" (MPD_TWO63)
            : "st", "memory"
    );

    return retval;
}
425
/*
 * Two modular multiplications in parallel:
 *      *a0 = (*a0 * w) % dmod
 *      *a1 = (*a1 * w) % dmod
 *
 * Operands: %0 = a0, %1 = a1 (pointers, read and written through memory),
 *           %2 = w, %3 = dmod (-> p as double),
 *           %4 = dinvmod (-> pinv as 80-bit long double), %5 = MPD_TWO63.
 * The FPU control word must be in 64-bit precision / truncation mode.
 *
 * NOTE(review): up to five x87 registers are live inside the asm, but
 * only "st" is declared clobbered. GCC's Extended Asm docs suggest
 * clobbering scratch stack registers to guarantee stack space. Confirm
 * that this is safe with the supported compilers.
 */
static inline void
ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
              double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            "fildl  %2\n\t"                 /* push w */
            "fildl  (%1)\n\t"               /* push *a1 */
            "fmul   %%st(1), %%st\n\t"      /* n1 = *a1 * w; stack: n1, w */
            "fxch   %%st(1)\n\t"            /* stack: w, n1 */
            "fildl  (%0)\n\t"               /* push *a0 */
            "fmulp  %%st, %%st(1) \n\t"     /* n0 = *a0 * w, pop; stack: n0, n1 */
            "fldt   (%4)\n\t"               /* push pinv */
            "flds   %5\n\t"                 /* push 2**63; stack: 2**63, pinv, n0, n1 */
            "fld    %%st(2)\n\t"            /* push copy of n0 */
            "fmul   %%st(2)\n\t"            /* q0est = n0 * pinv */
            "fadd   %%st(1)\n\t"            /* q0est + 2**63 */
            "fsub   %%st(1)\n\t"            /* q0 = (q0est + 2**63) - 2**63 */
            "fmull  (%3)\n\t"               /* q0 * p */
            "fsubrp %%st, %%st(3)\n\t"      /* r0 = n0 - q0 * p (replaces n0), pop */
            "fxch   %%st(2)\n\t"            /* bring r0 to the top */
            "fistpl (%0)\n\t"               /* *a0 = r0, pop; stack: pinv, 2**63, n1 */
            "fmul   %%st(2)\n\t"            /* q1est = pinv * n1 */
            "fadd   %%st(1)\n\t"            /* q1est + 2**63 */
            "fsubp  %%st, %%st(1)\n\t"      /* q1 = (q1est + 2**63) - 2**63, pop; stack: q1, n1 */
            "fmull  (%3)\n\t"               /* q1 * p */
            "fsubrp %%st, %%st(1)\n\t"      /* r1 = n1 - q1 * p, pop */
            "fistpl (%1)\n\t"               /* *a1 = r1, pop; stack empty */
            : : "r" (a0), "r" (a1), "m" (w),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
464
/*
 * Two modular multiplications in parallel:
 *      *a0 = (*a0 * b0) % dmod
 *      *a1 = (*a1 * b1) % dmod
 *
 * Operands: %0 = a0, %1 = b0, %2 = a1, %3 = b1, %4 = dmod (-> p as double),
 *           %5 = dinvmod (-> pinv as 80-bit long double), %6 = MPD_TWO63.
 * The two reductions are interleaved. The FPU control word must be in
 * 64-bit precision / truncation mode.
 *
 * NOTE(review): up to six x87 registers are live inside the asm, but
 * only "st" is declared clobbered. Confirm this against GCC's Extended Asm
 * rules for x87 stack usage.
 */
static inline void
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            "fildl  %3\n\t"                 /* push b1 */
            "fildl  (%2)\n\t"               /* push *a1 */
            "fmulp  %%st, %%st(1)\n\t"      /* n1 = *a1 * b1 */
            "fildl  %1\n\t"                 /* push b0 */
            "fildl  (%0)\n\t"               /* push *a0 */
            "fmulp  %%st, %%st(1)\n\t"      /* n0 = *a0 * b0; stack: n0, n1 */
            "fldt   (%5)\n\t"               /* push pinv */
            "fld    %%st(2)\n\t"            /* push copy of n1 */
            "fmul   %%st(1), %%st\n\t"      /* q1est = n1 * pinv */
            "fxch   %%st(1)\n\t"            /* bring pinv to the top */
            "fmul   %%st(2), %%st\n\t"      /* q0est = pinv * n0; stack: q0est, q1est, n0, n1 */
            "flds   %6\n\t"                 /* push 2**63 */
            "fldl   (%4)\n\t"               /* push p */
            "fxch   %%st(3)\n\t"            /* stack: q1est, 2**63, q0est, p, n0, n1 */
            "fadd   %%st(1), %%st\n\t"      /* q1est + 2**63 */
            "fxch   %%st(2)\n\t"
            "fadd   %%st(1), %%st\n\t"      /* q0est + 2**63 */
            "fxch   %%st(2)\n\t"
            "fsub   %%st(1), %%st\n\t"      /* q1 = (q1est + 2**63) - 2**63 */
            "fxch   %%st(2)\n\t"
            "fsubp  %%st, %%st(1)\n\t"      /* q0 = (q0est + 2**63) - 2**63, pop; stack: q0, q1, p, n0, n1 */
            "fxch   %%st(1)\n\t"
            "fmul   %%st(2), %%st\n\t"      /* q1 * p */
            "fxch   %%st(1)\n\t"
            "fmulp  %%st, %%st(2)\n\t"      /* q0 * p (replaces p), pop; stack: q1*p, q0*p, n0, n1 */
            "fsubrp %%st, %%st(3)\n\t"      /* r1 = n1 - q1 * p, pop */
            "fsubrp %%st, %%st(1)\n\t"      /* r0 = n0 - q0 * p, pop; stack: r0, r1 */
            "fxch   %%st(1)\n\t"            /* stack: r1, r0 */
            "fistpl (%2)\n\t"               /* *a1 = r1, pop */
            "fistpl (%0)\n\t"               /* *a0 = r0, pop; stack empty */
            : : "r" (a0), "m" (b0), "r" (a1), "m" (b1),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
511/* END PPRO GCC ASM */
512#elif defined(MASM)
513
/*
 * Return (a * b) % dmod, computed on the x87 FPU (MSVC inline assembly).
 * dmod points to p as a double; dinvmod points to pinv as an 80-bit long
 * double. The FPU control word must be in 64-bit precision / truncation
 * mode. The x87 stack is empty again when the asm block finishes.
 */
static inline mpd_uint_t __cdecl
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm {
        mov     eax, dinvmod            /* eax -> pinv */
        mov     edx, dmod               /* edx -> p */
        fild    b                       /* push b */
        fild    a                       /* push a */
        fmulp   st(1), st               /* n = a * b; stack: n */
        fld     TBYTE PTR [eax]         /* push pinv */
        fmul    st, st(1)               /* qest = pinv * n; stack: qest, n */
        fld     MPD_TWO63               /* push 2**63 */
        fadd    st(1), st               /* st(1) = qest + 2**63 */
        fsubp   st(1), st               /* q = (qest + 2**63) - 2**63, pop; stack: q, n */
        fld     QWORD PTR [edx]         /* push p */
        fmulp   st(1), st               /* stack: q * p, n */
        fsubp   st(1), st               /* r = n - q * p, pop; stack: r */
        fistp   retval                  /* retval = r, pop; stack empty */
    }

    return retval;
}
539
540/*
541 * Two modular multiplications in parallel:
542 *      *a0 = (*a0 * w) % dmod
543 *      *a1 = (*a1 * w) % dmod
544 */
545static inline mpd_uint_t __cdecl
546ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
547              double *dmod, uint32_t *dinvmod)
548{
549    __asm {
550        mov     ecx, dmod
551        mov     edx, a1
552        mov     ebx, dinvmod
553        mov     eax, a0
554        fild    w
555        fild    DWORD PTR [edx]
556        fmul    st, st(1)
557        fxch    st(1)
558        fild    DWORD PTR [eax]
559        fmulp   st(1), st
560        fld     TBYTE PTR [ebx]
561        fld     MPD_TWO63
562        fld     st(2)
563        fmul    st, st(2)
564        fadd    st, st(1)
565        fsub    st, st(1)
566        fmul    QWORD PTR [ecx]
567        fsubp   st(3), st
568        fxch    st(2)
569        fistp   DWORD PTR [eax]
570        fmul    st, st(2)
571        fadd    st, st(1)
572        fsubrp  st(1), st
573        fmul    QWORD PTR [ecx]
574        fsubp   st(1), st
575        fistp   DWORD PTR [edx]
576    }
577}
578
/*
 * Two modular multiplications in parallel:
 *      *a0 = (*a0 * b0) % dmod
 *      *a1 = (*a1 * b1) % dmod
 *
 * dmod points to p as a double; dinvmod points to pinv as an 80-bit long
 * double. The two reductions are interleaved. The FPU control word must be
 * in 64-bit precision / truncation mode.
 */
static inline void __cdecl
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm {
        mov     ecx, dmod               /* ecx -> p */
        mov     edx, a1
        mov     ebx, dinvmod            /* ebx -> pinv */
        mov     eax, a0
        fild    b1                      /* push b1 */
        fild    DWORD PTR [edx]         /* push *a1 */
        fmulp   st(1), st               /* n1 = *a1 * b1 */
        fild    b0                      /* push b0 */
        fild    DWORD PTR [eax]         /* push *a0 */
        fmulp   st(1), st               /* n0 = *a0 * b0; stack: n0, n1 */
        fld     TBYTE PTR [ebx]         /* push pinv */
        fld     st(2)                   /* push copy of n1 */
        fmul    st, st(1)               /* q1est = n1 * pinv */
        fxch    st(1)                   /* bring pinv to the top */
        fmul    st, st(2)               /* q0est = pinv * n0; stack: q0est, q1est, n0, n1 */
        fld     DWORD PTR MPD_TWO63     /* push 2**63 */
        fld     QWORD PTR [ecx]         /* push p */
        fxch    st(3)                   /* stack: q1est, 2**63, q0est, p, n0, n1 */
        fadd    st, st(1)               /* q1est + 2**63 */
        fxch    st(2)
        fadd    st, st(1)               /* q0est + 2**63 */
        fxch    st(2)
        fsub    st, st(1)               /* q1 = (q1est + 2**63) - 2**63 */
        fxch    st(2)
        fsubrp  st(1), st               /* q0 = (q0est + 2**63) - 2**63, pop; stack: q0, q1, p, n0, n1 */
        fxch    st(1)
        fmul    st, st(2)               /* q1 * p */
        fxch    st(1)
        fmulp   st(2), st               /* q0 * p (replaces p), pop; stack: q1*p, q0*p, n0, n1 */
        fsubp   st(3), st               /* r1 = n1 - q1 * p, pop */
        fsubp   st(1), st               /* r0 = n0 - q0 * p, pop; stack: r0, r1 */
        fxch    st(1)                   /* stack: r1, r0 */
        fistp   DWORD PTR [edx]         /* *a1 = r1, pop */
        fistp   DWORD PTR [eax]         /* *a0 = r0, pop; stack empty */
    }
}
625#endif /* PPRO MASM (_MSC_VER) */
626
627
628/* Return (base ** exp) % dmod */
629static inline mpd_uint_t
630ppro_powmod(mpd_uint_t base, mpd_uint_t exp, double *dmod, uint32_t *dinvmod)
631{
632    mpd_uint_t r = 1;
633
634    while (exp > 0) {
635        if (exp & 1)
636            r = ppro_mulmod(r, base, dmod, dinvmod);
637        base = ppro_mulmod(base, base, dmod, dinvmod);
638        exp >>= 1;
639    }
640
641    return r;
642}
643#endif /* PPRO */
644#endif /* CONFIG_32 */
645
646
647#endif /* UMODARITH_H */
648
649
650
651