1package org.bouncycastle.crypto.modes;
2
3import org.bouncycastle.crypto.BlockCipher;
4import org.bouncycastle.crypto.CipherParameters;
5import org.bouncycastle.crypto.DataLengthException;
6import org.bouncycastle.crypto.InvalidCipherTextException;
7import org.bouncycastle.crypto.modes.gcm.GCMExponentiator;
8import org.bouncycastle.crypto.modes.gcm.GCMMultiplier;
9import org.bouncycastle.crypto.modes.gcm.Tables1kGCMExponentiator;
10import org.bouncycastle.crypto.modes.gcm.Tables8kGCMMultiplier;
11import org.bouncycastle.crypto.params.AEADParameters;
12import org.bouncycastle.crypto.params.KeyParameter;
13import org.bouncycastle.crypto.params.ParametersWithIV;
14import org.bouncycastle.crypto.util.Pack;
15import org.bouncycastle.util.Arrays;
16
17/**
18 * Implements the Galois/Counter mode (GCM) detailed in
19 * NIST Special Publication 800-38D.
20 */
21public class GCMBlockCipher
22    implements AEADBlockCipher
23{
24    private static final int BLOCK_SIZE = 16;
25
26    // not final due to a compiler bug
27    private BlockCipher   cipher;
28    private GCMMultiplier multiplier;
29    private GCMExponentiator exp;
30
31    // These fields are set by init and not modified by processing
32    private boolean             forEncryption;
33    private int                 macSize;
34    private byte[]              nonce;
35    private byte[]              initialAssociatedText;
36    private byte[]              H;
37    private byte[]              J0;
38
39    // These fields are modified during processing
40    private byte[]      bufBlock;
41    private byte[]      macBlock;
42    private byte[]      S, S_at, S_atPre;
43    private byte[]      counter;
44    private int         bufOff;
45    private long        totalLength;
46    private byte[]      atBlock;
47    private int         atBlockPos;
48    private long        atLength;
49    private long        atLengthPre;
50
51    public GCMBlockCipher(BlockCipher c)
52    {
53        this(c, null);
54    }
55
56    public GCMBlockCipher(BlockCipher c, GCMMultiplier m)
57    {
58        if (c.getBlockSize() != BLOCK_SIZE)
59        {
60            throw new IllegalArgumentException(
61                "cipher required with a block size of " + BLOCK_SIZE + ".");
62        }
63
64        if (m == null)
65        {
66            // TODO Consider a static property specifying default multiplier
67            m = new Tables8kGCMMultiplier();
68        }
69
70        this.cipher = c;
71        this.multiplier = m;
72    }
73
74    public BlockCipher getUnderlyingCipher()
75    {
76        return cipher;
77    }
78
79    public String getAlgorithmName()
80    {
81        return cipher.getAlgorithmName() + "/GCM";
82    }
83
84    public void init(boolean forEncryption, CipherParameters params)
85        throws IllegalArgumentException
86    {
87        this.forEncryption = forEncryption;
88        this.macBlock = null;
89
90        KeyParameter keyParam;
91
92        if (params instanceof AEADParameters)
93        {
94            AEADParameters param = (AEADParameters)params;
95
96            nonce = param.getNonce();
97            initialAssociatedText = param.getAssociatedText();
98
99            int macSizeBits = param.getMacSize();
100            if (macSizeBits < 96 || macSizeBits > 128 || macSizeBits % 8 != 0)
101            {
102                throw new IllegalArgumentException("Invalid value for MAC size: " + macSizeBits);
103            }
104
105            macSize = macSizeBits / 8;
106            keyParam = param.getKey();
107        }
108        else if (params instanceof ParametersWithIV)
109        {
110            ParametersWithIV param = (ParametersWithIV)params;
111
112            nonce = param.getIV();
113            initialAssociatedText  = null;
114            macSize = 16;
115            keyParam = (KeyParameter)param.getParameters();
116        }
117        else
118        {
119            throw new IllegalArgumentException("invalid parameters passed to GCM");
120        }
121
122        int bufLength = forEncryption ? BLOCK_SIZE : (BLOCK_SIZE + macSize);
123        this.bufBlock = new byte[bufLength];
124
125        if (nonce == null || nonce.length < 1)
126        {
127            throw new IllegalArgumentException("IV must be at least 1 byte");
128        }
129
130        // TODO This should be configurable by init parameters
131        // (but must be 16 if nonce length not 12) (BLOCK_SIZE?)
132//        this.tagLength = 16;
133
134        // Cipher always used in forward mode
135        // if keyParam is null we're reusing the last key.
136        if (keyParam != null)
137        {
138            cipher.init(true, keyParam);
139
140            this.H = new byte[BLOCK_SIZE];
141            cipher.processBlock(H, 0, H, 0);
142
143            // GCMMultiplier tables don't change unless the key changes (and are expensive to init)
144            multiplier.init(H);
145            exp = null;
146        }
147
148        this.J0 = new byte[BLOCK_SIZE];
149
150        if (nonce.length == 12)
151        {
152            System.arraycopy(nonce, 0, J0, 0, nonce.length);
153            this.J0[BLOCK_SIZE - 1] = 0x01;
154        }
155        else
156        {
157            gHASH(J0, nonce, nonce.length);
158            byte[] X = new byte[BLOCK_SIZE];
159            Pack.longToBigEndian((long)nonce.length * 8, X, 8);
160            gHASHBlock(J0, X);
161        }
162
163        this.S = new byte[BLOCK_SIZE];
164        this.S_at = new byte[BLOCK_SIZE];
165        this.S_atPre = new byte[BLOCK_SIZE];
166        this.atBlock = new byte[BLOCK_SIZE];
167        this.atBlockPos = 0;
168        this.atLength = 0;
169        this.atLengthPre = 0;
170        this.counter = Arrays.clone(J0);
171        this.bufOff = 0;
172        this.totalLength = 0;
173
174        if (initialAssociatedText != null)
175        {
176            processAADBytes(initialAssociatedText, 0, initialAssociatedText.length);
177        }
178    }
179
180    public byte[] getMac()
181    {
182        return Arrays.clone(macBlock);
183    }
184
185    public int getOutputSize(int len)
186    {
187        int totalData = len + bufOff;
188
189        if (forEncryption)
190        {
191             return totalData + macSize;
192        }
193
194        return totalData < macSize ? 0 : totalData - macSize;
195    }
196
197    public int getUpdateOutputSize(int len)
198    {
199        int totalData = len + bufOff;
200        if (!forEncryption)
201        {
202            if (totalData < macSize)
203            {
204                return 0;
205            }
206            totalData -= macSize;
207        }
208        return totalData - totalData % BLOCK_SIZE;
209    }
210
211    public void processAADByte(byte in)
212    {
213        atBlock[atBlockPos] = in;
214        if (++atBlockPos == BLOCK_SIZE)
215        {
216            // Hash each block as it fills
217            gHASHBlock(S_at, atBlock);
218            atBlockPos = 0;
219            atLength += BLOCK_SIZE;
220        }
221    }
222
223    public void processAADBytes(byte[] in, int inOff, int len)
224    {
225        for (int i = 0; i < len; ++i)
226        {
227            atBlock[atBlockPos] = in[inOff + i];
228            if (++atBlockPos == BLOCK_SIZE)
229            {
230                // Hash each block as it fills
231                gHASHBlock(S_at, atBlock);
232                atBlockPos = 0;
233                atLength += BLOCK_SIZE;
234            }
235        }
236    }
237
238    private void initCipher()
239    {
240        if (atLength > 0)
241        {
242            System.arraycopy(S_at, 0, S_atPre, 0, BLOCK_SIZE);
243            atLengthPre = atLength;
244        }
245
246        // Finish hash for partial AAD block
247        if (atBlockPos > 0)
248        {
249            gHASHPartial(S_atPre, atBlock, 0, atBlockPos);
250            atLengthPre += atBlockPos;
251        }
252
253        if (atLengthPre > 0)
254        {
255            System.arraycopy(S_atPre, 0, S, 0, BLOCK_SIZE);
256        }
257    }
258
259    public int processByte(byte in, byte[] out, int outOff)
260        throws DataLengthException
261    {
262        bufBlock[bufOff] = in;
263        if (++bufOff == bufBlock.length)
264        {
265            outputBlock(out, outOff);
266            return BLOCK_SIZE;
267        }
268        return 0;
269    }
270
271    public int processBytes(byte[] in, int inOff, int len, byte[] out, int outOff)
272        throws DataLengthException
273    {
274        int resultLen = 0;
275
276        for (int i = 0; i < len; ++i)
277        {
278            bufBlock[bufOff] = in[inOff + i];
279            if (++bufOff == bufBlock.length)
280            {
281                outputBlock(out, outOff + resultLen);
282                resultLen += BLOCK_SIZE;
283            }
284        }
285
286        return resultLen;
287    }
288
289    private void outputBlock(byte[] output, int offset)
290    {
291        if (totalLength == 0)
292        {
293            initCipher();
294        }
295        gCTRBlock(bufBlock, output, offset);
296        if (forEncryption)
297        {
298            bufOff = 0;
299        }
300        else
301        {
302            System.arraycopy(bufBlock, BLOCK_SIZE, bufBlock, 0, macSize);
303            bufOff = macSize;
304        }
305    }
306
307    public int doFinal(byte[] out, int outOff)
308        throws IllegalStateException, InvalidCipherTextException
309    {
310        if (totalLength == 0)
311        {
312            initCipher();
313        }
314
315        int extra = bufOff;
316        if (!forEncryption)
317        {
318            if (extra < macSize)
319            {
320                throw new InvalidCipherTextException("data too short");
321            }
322            extra -= macSize;
323        }
324
325        if (extra > 0)
326        {
327            gCTRPartial(bufBlock, 0, extra, out, outOff);
328        }
329
330        atLength += atBlockPos;
331
332        if (atLength > atLengthPre)
333        {
334            /*
335             *  Some AAD was sent after the cipher started. We determine the difference b/w the hash value
336             *  we actually used when the cipher started (S_atPre) and the final hash value calculated (S_at).
337             *  Then we carry this difference forward by multiplying by H^c, where c is the number of (full or
338             *  partial) cipher-text blocks produced, and adjust the current hash.
339             */
340
341            // Finish hash for partial AAD block
342            if (atBlockPos > 0)
343            {
344                gHASHPartial(S_at, atBlock, 0, atBlockPos);
345            }
346
347            // Find the difference between the AAD hashes
348            if (atLengthPre > 0)
349            {
350                xor(S_at, S_atPre);
351            }
352
353            // Number of cipher-text blocks produced
354            long c = ((totalLength * 8) + 127) >>> 7;
355
356            // Calculate the adjustment factor
357            byte[] H_c = new byte[16];
358            if (exp == null)
359            {
360                exp = new Tables1kGCMExponentiator();
361                exp.init(H);
362            }
363            exp.exponentiateX(c, H_c);
364
365            // Carry the difference forward
366            multiply(S_at, H_c);
367
368            // Adjust the current hash
369            xor(S, S_at);
370        }
371
372        // Final gHASH
373        byte[] X = new byte[BLOCK_SIZE];
374        Pack.longToBigEndian(atLength * 8, X, 0);
375        Pack.longToBigEndian(totalLength * 8, X, 8);
376
377        gHASHBlock(S, X);
378
379        // TODO Fix this if tagLength becomes configurable
380        // T = MSBt(GCTRk(J0,S))
381        byte[] tag = new byte[BLOCK_SIZE];
382        cipher.processBlock(J0, 0, tag, 0);
383        xor(tag, S);
384
385        int resultLen = extra;
386
387        // We place into macBlock our calculated value for T
388        this.macBlock = new byte[macSize];
389        System.arraycopy(tag, 0, macBlock, 0, macSize);
390
391        if (forEncryption)
392        {
393            // Append T to the message
394            System.arraycopy(macBlock, 0, out, outOff + bufOff, macSize);
395            resultLen += macSize;
396        }
397        else
398        {
399            // Retrieve the T value from the message and compare to calculated one
400            byte[] msgMac = new byte[macSize];
401            System.arraycopy(bufBlock, extra, msgMac, 0, macSize);
402            if (!Arrays.constantTimeAreEqual(this.macBlock, msgMac))
403            {
404                throw new InvalidCipherTextException("mac check in GCM failed");
405            }
406        }
407
408        reset(false);
409
410        return resultLen;
411    }
412
413    public void reset()
414    {
415        reset(true);
416    }
417
418    private void reset(
419        boolean clearMac)
420    {
421        cipher.reset();
422
423        S = new byte[BLOCK_SIZE];
424        S_at = new byte[BLOCK_SIZE];
425        S_atPre = new byte[BLOCK_SIZE];
426        atBlock = new byte[BLOCK_SIZE];
427        atBlockPos = 0;
428        atLength = 0;
429        atLengthPre = 0;
430        counter = Arrays.clone(J0);
431        bufOff = 0;
432        totalLength = 0;
433
434        if (bufBlock != null)
435        {
436            Arrays.fill(bufBlock, (byte)0);
437        }
438
439        if (clearMac)
440        {
441            macBlock = null;
442        }
443
444        if (initialAssociatedText != null)
445        {
446            processAADBytes(initialAssociatedText, 0, initialAssociatedText.length);
447        }
448    }
449
450    private void gCTRBlock(byte[] block, byte[] out, int outOff)
451    {
452        byte[] tmp = getNextCounterBlock();
453
454        xor(tmp, block);
455        System.arraycopy(tmp, 0, out, outOff, BLOCK_SIZE);
456
457        gHASHBlock(S, forEncryption ? tmp : block);
458
459        totalLength += BLOCK_SIZE;
460    }
461
462    private void gCTRPartial(byte[] buf, int off, int len, byte[] out, int outOff)
463    {
464        byte[] tmp = getNextCounterBlock();
465
466        xor(tmp, buf, off, len);
467        System.arraycopy(tmp, 0, out, outOff, len);
468
469        gHASHPartial(S, forEncryption ? tmp : buf, 0, len);
470
471        totalLength += len;
472    }
473
474    private void gHASH(byte[] Y, byte[] b, int len)
475    {
476        for (int pos = 0; pos < len; pos += BLOCK_SIZE)
477        {
478            int num = Math.min(len - pos, BLOCK_SIZE);
479            gHASHPartial(Y, b, pos, num);
480        }
481    }
482
483    private void gHASHBlock(byte[] Y, byte[] b)
484    {
485        xor(Y, b);
486        multiplier.multiplyH(Y);
487    }
488
489    private void gHASHPartial(byte[] Y, byte[] b, int off, int len)
490    {
491        xor(Y, b, off, len);
492        multiplier.multiplyH(Y);
493    }
494
495    private byte[] getNextCounterBlock()
496    {
497        for (int i = 15; i >= 12; --i)
498        {
499            byte b = (byte)((counter[i] + 1) & 0xff);
500            counter[i] = b;
501
502            if (b != 0)
503            {
504                break;
505            }
506        }
507
508        byte[] tmp = new byte[BLOCK_SIZE];
509        // TODO Sure would be nice if ciphers could operate on int[]
510        cipher.processBlock(counter, 0, tmp, 0);
511        return tmp;
512    }
513
514    private static void multiply(byte[] block, byte[] val)
515    {
516        byte[] tmp = Arrays.clone(block);
517        byte[] c = new byte[16];
518
519        for (int i = 0; i < 16; ++i)
520        {
521            byte bits = val[i];
522            for (int j = 7; j >= 0; --j)
523            {
524                if ((bits & (1 << j)) != 0)
525                {
526                    xor(c, tmp);
527                }
528
529                boolean lsb = (tmp[15] & 1) != 0;
530                shiftRight(tmp);
531                if (lsb)
532                {
533                    // R = new byte[]{ 0xe1, ... };
534//                    xor(v, R);
535                    tmp[0] ^= (byte)0xe1;
536                }
537            }
538        }
539
540        System.arraycopy(c, 0, block, 0, 16);
541    }
542
543    private static void shiftRight(byte[] block)
544    {
545        int i = 0;
546        int bit = 0;
547        for (;;)
548        {
549            int b = block[i] & 0xff;
550            block[i] = (byte) ((b >>> 1) | bit);
551            if (++i == 16)
552            {
553                break;
554            }
555            bit = (b & 1) << 7;
556        }
557    }
558
559    private static void xor(byte[] block, byte[] val)
560    {
561        for (int i = 15; i >= 0; --i)
562        {
563            block[i] ^= val[i];
564        }
565    }
566
567    private static void xor(byte[] block, byte[] val, int off, int len)
568    {
569        while (len-- > 0)
570        {
571            block[len] ^= val[off + len];
572        }
573    }
574}
575