1package org.bouncycastle.math.ec;
2
3import java.math.BigInteger;
4
5import org.bouncycastle.math.ec.endo.ECEndomorphism;
6import org.bouncycastle.math.ec.endo.GLVEndomorphism;
7import org.bouncycastle.math.field.FiniteField;
8import org.bouncycastle.math.field.PolynomialExtensionField;
9
10public class ECAlgorithms
11{
12    public static boolean isF2mCurve(ECCurve c)
13    {
14        FiniteField field = c.getField();
15        return field.getDimension() > 1 && field.getCharacteristic().equals(ECConstants.TWO)
16            && field instanceof PolynomialExtensionField;
17    }
18
19    public static boolean isFpCurve(ECCurve c)
20    {
21        return c.getField().getDimension() == 1;
22    }
23
24    public static ECPoint sumOfMultiplies(ECPoint[] ps, BigInteger[] ks)
25    {
26        if (ps == null || ks == null || ps.length != ks.length || ps.length < 1)
27        {
28            throw new IllegalArgumentException("point and scalar arrays should be non-null, and of equal, non-zero, length");
29        }
30
31        int count = ps.length;
32        switch (count)
33        {
34        case 1:
35            return ps[0].multiply(ks[0]);
36        case 2:
37            return sumOfTwoMultiplies(ps[0], ks[0], ps[1], ks[1]);
38        default:
39            break;
40        }
41
42        ECPoint p = ps[0];
43        ECCurve c = p.getCurve();
44
45        ECPoint[] imported = new ECPoint[count];
46        imported[0] = p;
47        for (int i = 1; i < count; ++i)
48        {
49            imported[i] = importPoint(c, ps[i]);
50        }
51
52        ECEndomorphism endomorphism = c.getEndomorphism();
53        if (endomorphism instanceof GLVEndomorphism)
54        {
55            return validatePoint(implSumOfMultipliesGLV(imported, ks, (GLVEndomorphism)endomorphism));
56        }
57
58        return validatePoint(implSumOfMultiplies(imported, ks));
59    }
60
61    public static ECPoint sumOfTwoMultiplies(ECPoint P, BigInteger a,
62        ECPoint Q, BigInteger b)
63    {
64        ECCurve cp = P.getCurve();
65        Q = importPoint(cp, Q);
66
67        // Point multiplication for Koblitz curves (using WTNAF) beats Shamir's trick
68        if (cp instanceof ECCurve.F2m)
69        {
70            ECCurve.F2m f2mCurve = (ECCurve.F2m)cp;
71            if (f2mCurve.isKoblitz())
72            {
73                return validatePoint(P.multiply(a).add(Q.multiply(b)));
74            }
75        }
76
77        ECEndomorphism endomorphism = cp.getEndomorphism();
78        if (endomorphism instanceof GLVEndomorphism)
79        {
80            return validatePoint(
81                implSumOfMultipliesGLV(new ECPoint[]{ P, Q }, new BigInteger[]{ a, b }, (GLVEndomorphism)endomorphism));
82        }
83
84        return validatePoint(implShamirsTrickWNaf(P, a, Q, b));
85    }
86
87    /*
88     * "Shamir's Trick", originally due to E. G. Straus
89     * (Addition chains of vectors. American Mathematical Monthly,
90     * 71(7):806-808, Aug./Sept. 1964)
91     * <pre>
92     * Input: The points P, Q, scalar k = (km?, ... , k1, k0)
93     * and scalar l = (lm?, ... , l1, l0).
94     * Output: R = k * P + l * Q.
95     * 1: Z <- P + Q
96     * 2: R <- O
97     * 3: for i from m-1 down to 0 do
98     * 4:        R <- R + R        {point doubling}
99     * 5:        if (ki = 1) and (li = 0) then R <- R + P end if
100     * 6:        if (ki = 0) and (li = 1) then R <- R + Q end if
101     * 7:        if (ki = 1) and (li = 1) then R <- R + Z end if
102     * 8: end for
103     * 9: return R
104     * </pre>
105     */
106    public static ECPoint shamirsTrick(ECPoint P, BigInteger k,
107        ECPoint Q, BigInteger l)
108    {
109        ECCurve cp = P.getCurve();
110        Q = importPoint(cp, Q);
111
112        return validatePoint(implShamirsTrickJsf(P, k, Q, l));
113    }
114
115    public static ECPoint importPoint(ECCurve c, ECPoint p)
116    {
117        ECCurve cp = p.getCurve();
118        if (!c.equals(cp))
119        {
120            throw new IllegalArgumentException("Point must be on the same curve");
121        }
122        return c.importPoint(p);
123    }
124
125    public static void montgomeryTrick(ECFieldElement[] zs, int off, int len)
126    {
127        montgomeryTrick(zs, off, len, null);
128    }
129
130    public static void montgomeryTrick(ECFieldElement[] zs, int off, int len, ECFieldElement scale)
131    {
132        /*
133         * Uses the "Montgomery Trick" to invert many field elements, with only a single actual
134         * field inversion. See e.g. the paper:
135         * "Fast Multi-scalar Multiplication Methods on Elliptic Curves with Precomputation Strategy Using Montgomery Trick"
136         * by Katsuyuki Okeya, Kouichi Sakurai.
137         */
138
139        ECFieldElement[] c = new ECFieldElement[len];
140        c[0] = zs[off];
141
142        int i = 0;
143        while (++i < len)
144        {
145            c[i] = c[i - 1].multiply(zs[off + i]);
146        }
147
148        --i;
149
150        if (scale != null)
151        {
152            c[i] = c[i].multiply(scale);
153        }
154
155        ECFieldElement u = c[i].invert();
156
157        while (i > 0)
158        {
159            int j = off + i--;
160            ECFieldElement tmp = zs[j];
161            zs[j] = c[i].multiply(u);
162            u = u.multiply(tmp);
163        }
164
165        zs[off] = u;
166    }
167
168    /**
169     * Simple shift-and-add multiplication. Serves as reference implementation
170     * to verify (possibly faster) implementations, and for very small scalars.
171     *
172     * @param p
173     *            The point to multiply.
174     * @param k
175     *            The multiplier.
176     * @return The result of the point multiplication <code>kP</code>.
177     */
178    public static ECPoint referenceMultiply(ECPoint p, BigInteger k)
179    {
180        BigInteger x = k.abs();
181        ECPoint q = p.getCurve().getInfinity();
182        int t = x.bitLength();
183        if (t > 0)
184        {
185            if (x.testBit(0))
186            {
187                q = p;
188            }
189            for (int i = 1; i < t; i++)
190            {
191                p = p.twice();
192                if (x.testBit(i))
193                {
194                    q = q.add(p);
195                }
196            }
197        }
198        return k.signum() < 0 ? q.negate() : q;
199    }
200
201    public static ECPoint validatePoint(ECPoint p)
202    {
203        if (!p.isValid())
204        {
205            throw new IllegalArgumentException("Invalid point");
206        }
207
208        return p;
209    }
210
211    static ECPoint implShamirsTrickJsf(ECPoint P, BigInteger k,
212        ECPoint Q, BigInteger l)
213    {
214        ECCurve curve = P.getCurve();
215        ECPoint infinity = curve.getInfinity();
216
217        // TODO conjugate co-Z addition (ZADDC) can return both of these
218        ECPoint PaddQ = P.add(Q);
219        ECPoint PsubQ = P.subtract(Q);
220
221        ECPoint[] points = new ECPoint[]{ Q, PsubQ, P, PaddQ };
222        curve.normalizeAll(points);
223
224        ECPoint[] table = new ECPoint[] {
225            points[3].negate(), points[2].negate(), points[1].negate(),
226            points[0].negate(), infinity, points[0],
227            points[1], points[2], points[3] };
228
229        byte[] jsf = WNafUtil.generateJSF(k, l);
230
231        ECPoint R = infinity;
232
233        int i = jsf.length;
234        while (--i >= 0)
235        {
236            int jsfi = jsf[i];
237
238            // NOTE: The shifting ensures the sign is extended correctly
239            int kDigit = ((jsfi << 24) >> 28), lDigit = ((jsfi << 28) >> 28);
240
241            int index = 4 + (kDigit * 3) + lDigit;
242            R = R.twicePlus(table[index]);
243        }
244
245        return R;
246    }
247
248    static ECPoint implShamirsTrickWNaf(ECPoint P, BigInteger k,
249        ECPoint Q, BigInteger l)
250    {
251        boolean negK = k.signum() < 0, negL = l.signum() < 0;
252
253        k = k.abs();
254        l = l.abs();
255
256        int widthP = Math.max(2, Math.min(16, WNafUtil.getWindowSize(k.bitLength())));
257        int widthQ = Math.max(2, Math.min(16, WNafUtil.getWindowSize(l.bitLength())));
258
259        WNafPreCompInfo infoP = WNafUtil.precompute(P, widthP, true);
260        WNafPreCompInfo infoQ = WNafUtil.precompute(Q, widthQ, true);
261
262        ECPoint[] preCompP = negK ? infoP.getPreCompNeg() : infoP.getPreComp();
263        ECPoint[] preCompQ = negL ? infoQ.getPreCompNeg() : infoQ.getPreComp();
264        ECPoint[] preCompNegP = negK ? infoP.getPreComp() : infoP.getPreCompNeg();
265        ECPoint[] preCompNegQ = negL ? infoQ.getPreComp() : infoQ.getPreCompNeg();
266
267        byte[] wnafP = WNafUtil.generateWindowNaf(widthP, k);
268        byte[] wnafQ = WNafUtil.generateWindowNaf(widthQ, l);
269
270        return implShamirsTrickWNaf(preCompP, preCompNegP, wnafP, preCompQ, preCompNegQ, wnafQ);
271    }
272
273    static ECPoint implShamirsTrickWNaf(ECPoint P, BigInteger k, ECPointMap pointMapQ, BigInteger l)
274    {
275        boolean negK = k.signum() < 0, negL = l.signum() < 0;
276
277        k = k.abs();
278        l = l.abs();
279
280        int width = Math.max(2, Math.min(16, WNafUtil.getWindowSize(Math.max(k.bitLength(), l.bitLength()))));
281
282        ECPoint Q = WNafUtil.mapPointWithPrecomp(P, width, true, pointMapQ);
283        WNafPreCompInfo infoP = WNafUtil.getWNafPreCompInfo(P);
284        WNafPreCompInfo infoQ = WNafUtil.getWNafPreCompInfo(Q);
285
286        ECPoint[] preCompP = negK ? infoP.getPreCompNeg() : infoP.getPreComp();
287        ECPoint[] preCompQ = negL ? infoQ.getPreCompNeg() : infoQ.getPreComp();
288        ECPoint[] preCompNegP = negK ? infoP.getPreComp() : infoP.getPreCompNeg();
289        ECPoint[] preCompNegQ = negL ? infoQ.getPreComp() : infoQ.getPreCompNeg();
290
291        byte[] wnafP = WNafUtil.generateWindowNaf(width, k);
292        byte[] wnafQ = WNafUtil.generateWindowNaf(width, l);
293
294        return implShamirsTrickWNaf(preCompP, preCompNegP, wnafP, preCompQ, preCompNegQ, wnafQ);
295    }
296
297    private static ECPoint implShamirsTrickWNaf(ECPoint[] preCompP, ECPoint[] preCompNegP, byte[] wnafP,
298        ECPoint[] preCompQ, ECPoint[] preCompNegQ, byte[] wnafQ)
299    {
300        int len = Math.max(wnafP.length, wnafQ.length);
301
302        ECCurve curve = preCompP[0].getCurve();
303        ECPoint infinity = curve.getInfinity();
304
305        ECPoint R = infinity;
306        int zeroes = 0;
307
308        for (int i = len - 1; i >= 0; --i)
309        {
310            int wiP = i < wnafP.length ? wnafP[i] : 0;
311            int wiQ = i < wnafQ.length ? wnafQ[i] : 0;
312
313            if ((wiP | wiQ) == 0)
314            {
315                ++zeroes;
316                continue;
317            }
318
319            ECPoint r = infinity;
320            if (wiP != 0)
321            {
322                int nP = Math.abs(wiP);
323                ECPoint[] tableP = wiP < 0 ? preCompNegP : preCompP;
324                r = r.add(tableP[nP >>> 1]);
325            }
326            if (wiQ != 0)
327            {
328                int nQ = Math.abs(wiQ);
329                ECPoint[] tableQ = wiQ < 0 ? preCompNegQ : preCompQ;
330                r = r.add(tableQ[nQ >>> 1]);
331            }
332
333            if (zeroes > 0)
334            {
335                R = R.timesPow2(zeroes);
336                zeroes = 0;
337            }
338
339            R = R.twicePlus(r);
340        }
341
342        if (zeroes > 0)
343        {
344            R = R.timesPow2(zeroes);
345        }
346
347        return R;
348    }
349
350    static ECPoint implSumOfMultiplies(ECPoint[] ps, BigInteger[] ks)
351    {
352        int count = ps.length;
353        boolean[] negs = new boolean[count];
354        WNafPreCompInfo[] infos = new WNafPreCompInfo[count];
355        byte[][] wnafs = new byte[count][];
356
357        for (int i = 0; i < count; ++i)
358        {
359            BigInteger ki = ks[i]; negs[i] = ki.signum() < 0; ki = ki.abs();
360
361            int width = Math.max(2, Math.min(16, WNafUtil.getWindowSize(ki.bitLength())));
362            infos[i] = WNafUtil.precompute(ps[i], width, true);
363            wnafs[i] = WNafUtil.generateWindowNaf(width, ki);
364        }
365
366        return implSumOfMultiplies(negs, infos, wnafs);
367    }
368
369    static ECPoint implSumOfMultipliesGLV(ECPoint[] ps, BigInteger[] ks, GLVEndomorphism glvEndomorphism)
370    {
371        BigInteger n = ps[0].getCurve().getOrder();
372
373        int len = ps.length;
374
375        BigInteger[] abs = new BigInteger[len << 1];
376        for (int i = 0, j = 0; i < len; ++i)
377        {
378            BigInteger[] ab = glvEndomorphism.decomposeScalar(ks[i].mod(n));
379            abs[j++] = ab[0];
380            abs[j++] = ab[1];
381        }
382
383        ECPointMap pointMap = glvEndomorphism.getPointMap();
384        if (glvEndomorphism.hasEfficientPointMap())
385        {
386            return ECAlgorithms.implSumOfMultiplies(ps, pointMap, abs);
387        }
388
389        ECPoint[] pqs = new ECPoint[len << 1];
390        for (int i = 0, j = 0; i < len; ++i)
391        {
392            ECPoint p = ps[i], q = pointMap.map(p);
393            pqs[j++] = p;
394            pqs[j++] = q;
395        }
396
397        return ECAlgorithms.implSumOfMultiplies(pqs, abs);
398
399    }
400
401    static ECPoint implSumOfMultiplies(ECPoint[] ps, ECPointMap pointMap, BigInteger[] ks)
402    {
403        int halfCount = ps.length, fullCount = halfCount << 1;
404
405        boolean[] negs = new boolean[fullCount];
406        WNafPreCompInfo[] infos = new WNafPreCompInfo[fullCount];
407        byte[][] wnafs = new byte[fullCount][];
408
409        for (int i = 0; i < halfCount; ++i)
410        {
411            int j0 = i << 1, j1 = j0 + 1;
412
413            BigInteger kj0 = ks[j0]; negs[j0] = kj0.signum() < 0; kj0 = kj0.abs();
414            BigInteger kj1 = ks[j1]; negs[j1] = kj1.signum() < 0; kj1 = kj1.abs();
415
416            int width = Math.max(2, Math.min(16, WNafUtil.getWindowSize(Math.max(kj0.bitLength(), kj1.bitLength()))));
417
418            ECPoint P = ps[i], Q = WNafUtil.mapPointWithPrecomp(P, width, true, pointMap);
419            infos[j0] = WNafUtil.getWNafPreCompInfo(P);
420            infos[j1] = WNafUtil.getWNafPreCompInfo(Q);
421            wnafs[j0] = WNafUtil.generateWindowNaf(width, kj0);
422            wnafs[j1] = WNafUtil.generateWindowNaf(width, kj1);
423        }
424
425        return implSumOfMultiplies(negs, infos, wnafs);
426    }
427
428    private static ECPoint implSumOfMultiplies(boolean[] negs, WNafPreCompInfo[] infos, byte[][] wnafs)
429    {
430        int len = 0, count = wnafs.length;
431        for (int i = 0; i < count; ++i)
432        {
433            len = Math.max(len, wnafs[i].length);
434        }
435
436        ECCurve curve = infos[0].getPreComp()[0].getCurve();
437        ECPoint infinity = curve.getInfinity();
438
439        ECPoint R = infinity;
440        int zeroes = 0;
441
442        for (int i = len - 1; i >= 0; --i)
443        {
444            ECPoint r = infinity;
445
446            for (int j = 0; j < count; ++j)
447            {
448                byte[] wnaf = wnafs[j];
449                int wi = i < wnaf.length ? wnaf[i] : 0;
450                if (wi != 0)
451                {
452                    int n = Math.abs(wi);
453                    WNafPreCompInfo info = infos[j];
454                    ECPoint[] table = (wi < 0 == negs[j]) ? info.getPreComp() : info.getPreCompNeg();
455                    r = r.add(table[n >>> 1]);
456                }
457            }
458
459            if (r == infinity)
460            {
461                ++zeroes;
462                continue;
463            }
464
465            if (zeroes > 0)
466            {
467                R = R.timesPow2(zeroes);
468                zeroes = 0;
469            }
470
471            R = R.twicePlus(r);
472        }
473
474        if (zeroes > 0)
475        {
476            R = R.timesPow2(zeroes);
477        }
478
479        return R;
480    }
481}
482