1package org.bouncycastle.crypto.modes;
2
3import java.io.ByteArrayOutputStream;
4
5import org.bouncycastle.crypto.BlockCipher;
6import org.bouncycastle.crypto.CipherParameters;
7import org.bouncycastle.crypto.DataLengthException;
8import org.bouncycastle.crypto.InvalidCipherTextException;
9import org.bouncycastle.crypto.Mac;
10import org.bouncycastle.crypto.macs.CBCBlockCipherMac;
11import org.bouncycastle.crypto.params.AEADParameters;
12import org.bouncycastle.crypto.params.ParametersWithIV;
13import org.bouncycastle.util.Arrays;
14
15/**
16 * Implements the Counter with Cipher Block Chaining mode (CCM) detailed in
17 * NIST Special Publication 800-38C.
18 * <p>
19 * <b>Note</b>: this mode is a packet mode - it needs all the data up front.
20 */
21public class CCMBlockCipher
22    implements AEADBlockCipher
23{
24    private BlockCipher           cipher;
25    private int                   blockSize;
26    private boolean               forEncryption;
27    private byte[]                nonce;
28    private byte[]                initialAssociatedText;
29    private int                   macSize;
30    private CipherParameters      keyParam;
31    private byte[]                macBlock;
32    private ExposedByteArrayOutputStream associatedText = new ExposedByteArrayOutputStream();
33    private ExposedByteArrayOutputStream data = new ExposedByteArrayOutputStream();
34
35    /**
36     * Basic constructor.
37     *
38     * @param c the block cipher to be used.
39     */
40    public CCMBlockCipher(BlockCipher c)
41    {
42        this.cipher = c;
43        this.blockSize = c.getBlockSize();
44        this.macBlock = new byte[blockSize];
45
46        if (blockSize != 16)
47        {
48            throw new IllegalArgumentException("cipher required with a block size of 16.");
49        }
50    }
51
52    /**
53     * return the underlying block cipher that we are wrapping.
54     *
55     * @return the underlying block cipher that we are wrapping.
56     */
57    public BlockCipher getUnderlyingCipher()
58    {
59        return cipher;
60    }
61
62
63    public void init(boolean forEncryption, CipherParameters params)
64          throws IllegalArgumentException
65    {
66        this.forEncryption = forEncryption;
67
68        CipherParameters cipherParameters;
69        if (params instanceof AEADParameters)
70        {
71            AEADParameters param = (AEADParameters)params;
72
73            nonce = param.getNonce();
74            initialAssociatedText = param.getAssociatedText();
75            macSize = param.getMacSize() / 8;
76            cipherParameters = param.getKey();
77        }
78        else if (params instanceof ParametersWithIV)
79        {
80            ParametersWithIV param = (ParametersWithIV)params;
81
82            nonce = param.getIV();
83            initialAssociatedText = null;
84            macSize = macBlock.length / 2;
85            cipherParameters = param.getParameters();
86        }
87        else
88        {
89            throw new IllegalArgumentException("invalid parameters passed to CCM");
90        }
91
92        // NOTE: Very basic support for key re-use, but no performance gain from it
93        if (cipherParameters != null)
94        {
95            keyParam = cipherParameters;
96        }
97
98        if (nonce == null || nonce.length < 7 || nonce.length > 13)
99        {
100            throw new IllegalArgumentException("nonce must have length from 7 to 13 octets");
101        }
102
103        reset();
104    }
105
106    public String getAlgorithmName()
107    {
108        return cipher.getAlgorithmName() + "/CCM";
109    }
110
111    public void processAADByte(byte in)
112    {
113        associatedText.write(in);
114    }
115
116    public void processAADBytes(byte[] in, int inOff, int len)
117    {
118        // TODO: Process AAD online
119        associatedText.write(in, inOff, len);
120    }
121
122    public int processByte(byte in, byte[] out, int outOff)
123        throws DataLengthException, IllegalStateException
124    {
125        data.write(in);
126
127        return 0;
128    }
129
130    public int processBytes(byte[] in, int inOff, int inLen, byte[] out, int outOff)
131        throws DataLengthException, IllegalStateException
132    {
133        data.write(in, inOff, inLen);
134
135        return 0;
136    }
137
138    public int doFinal(byte[] out, int outOff)
139        throws IllegalStateException, InvalidCipherTextException
140    {
141        int len = processPacket(data.getBuffer(), 0, data.size(), out, outOff);
142
143        reset();
144
145        return len;
146    }
147
148    public void reset()
149    {
150        cipher.reset();
151        associatedText.reset();
152        data.reset();
153    }
154
155    /**
156     * Returns a byte array containing the mac calculated as part of the
157     * last encrypt or decrypt operation.
158     *
159     * @return the last mac calculated.
160     */
161    public byte[] getMac()
162    {
163        byte[] mac = new byte[macSize];
164
165        System.arraycopy(macBlock, 0, mac, 0, mac.length);
166
167        return mac;
168    }
169
170    public int getUpdateOutputSize(int len)
171    {
172        return 0;
173    }
174
175    public int getOutputSize(int len)
176    {
177        int totalData = len + data.size();
178
179        if (forEncryption)
180        {
181             return totalData + macSize;
182        }
183
184        return totalData < macSize ? 0 : totalData - macSize;
185    }
186
187    /**
188     * Process a packet of data for either CCM decryption or encryption.
189     *
190     * @param in data for processing.
191     * @param inOff offset at which data starts in the input array.
192     * @param inLen length of the data in the input array.
193     * @return a byte array containing the processed input..
194     * @throws IllegalStateException if the cipher is not appropriately set up.
195     * @throws InvalidCipherTextException if the input data is truncated or the mac check fails.
196     */
197    public byte[] processPacket(byte[] in, int inOff, int inLen)
198        throws IllegalStateException, InvalidCipherTextException
199    {
200        byte[] output;
201
202        if (forEncryption)
203        {
204            output = new byte[inLen + macSize];
205        }
206        else
207        {
208            if (inLen < macSize)
209            {
210                throw new InvalidCipherTextException("data too short");
211            }
212            output = new byte[inLen - macSize];
213        }
214
215        processPacket(in, inOff, inLen, output, 0);
216
217        return output;
218    }
219
220    /**
221     * Process a packet of data for either CCM decryption or encryption.
222     *
223     * @param in data for processing.
224     * @param inOff offset at which data starts in the input array.
225     * @param inLen length of the data in the input array.
226     * @param output output array.
227     * @param outOff offset into output array to start putting processed bytes.
228     * @return the number of bytes added to output.
229     * @throws IllegalStateException if the cipher is not appropriately set up.
230     * @throws InvalidCipherTextException if the input data is truncated or the mac check fails.
231     * @throws DataLengthException if output buffer too short.
232     */
233    public int processPacket(byte[] in, int inOff, int inLen, byte[] output, int outOff)
234        throws IllegalStateException, InvalidCipherTextException, DataLengthException
235    {
236        // TODO: handle null keyParam (e.g. via RepeatedKeySpec)
237        // Need to keep the CTR and CBC Mac parts around and reset
238        if (keyParam == null)
239        {
240            throw new IllegalStateException("CCM cipher unitialized.");
241        }
242
243        int n = nonce.length;
244        int q = 15 - n;
245        if (q < 4)
246        {
247            int limitLen = 1 << (8 * q);
248            if (inLen >= limitLen)
249            {
250                throw new IllegalStateException("CCM packet too large for choice of q.");
251            }
252        }
253
254        byte[] iv = new byte[blockSize];
255        iv[0] = (byte)((q - 1) & 0x7);
256        System.arraycopy(nonce, 0, iv, 1, nonce.length);
257
258        BlockCipher ctrCipher = new SICBlockCipher(cipher);
259        ctrCipher.init(forEncryption, new ParametersWithIV(keyParam, iv));
260
261        int outputLen;
262        int inIndex = inOff;
263        int outIndex = outOff;
264
265        if (forEncryption)
266        {
267            outputLen = inLen + macSize;
268            if (output.length < (outputLen + outOff))
269            {
270                throw new DataLengthException("Output buffer too short.");
271            }
272
273            calculateMac(in, inOff, inLen, macBlock);
274
275            ctrCipher.processBlock(macBlock, 0, macBlock, 0);   // S0
276
277            while (inIndex < (inOff + inLen - blockSize))                 // S1...
278            {
279                ctrCipher.processBlock(in, inIndex, output, outIndex);
280                outIndex += blockSize;
281                inIndex += blockSize;
282            }
283
284            byte[] block = new byte[blockSize];
285
286            System.arraycopy(in, inIndex, block, 0, inLen + inOff - inIndex);
287
288            ctrCipher.processBlock(block, 0, block, 0);
289
290            System.arraycopy(block, 0, output, outIndex, inLen + inOff - inIndex);
291
292            System.arraycopy(macBlock, 0, output, outOff + inLen, macSize);
293        }
294        else
295        {
296            if (inLen < macSize)
297            {
298                throw new InvalidCipherTextException("data too short");
299            }
300            outputLen = inLen - macSize;
301            if (output.length < (outputLen + outOff))
302            {
303                throw new DataLengthException("Output buffer too short.");
304            }
305
306            System.arraycopy(in, inOff + outputLen, macBlock, 0, macSize);
307
308            ctrCipher.processBlock(macBlock, 0, macBlock, 0);
309
310            for (int i = macSize; i != macBlock.length; i++)
311            {
312                macBlock[i] = 0;
313            }
314
315            while (inIndex < (inOff + outputLen - blockSize))
316            {
317                ctrCipher.processBlock(in, inIndex, output, outIndex);
318                outIndex += blockSize;
319                inIndex += blockSize;
320            }
321
322            byte[] block = new byte[blockSize];
323
324            System.arraycopy(in, inIndex, block, 0, outputLen - (inIndex - inOff));
325
326            ctrCipher.processBlock(block, 0, block, 0);
327
328            System.arraycopy(block, 0, output, outIndex, outputLen - (inIndex - inOff));
329
330            byte[] calculatedMacBlock = new byte[blockSize];
331
332            calculateMac(output, outOff, outputLen, calculatedMacBlock);
333
334            if (!Arrays.constantTimeAreEqual(macBlock, calculatedMacBlock))
335            {
336                throw new InvalidCipherTextException("mac check in CCM failed");
337            }
338        }
339
340        return outputLen;
341    }
342
343    private int calculateMac(byte[] data, int dataOff, int dataLen, byte[] macBlock)
344    {
345        Mac cMac = new CBCBlockCipherMac(cipher, macSize * 8);
346
347        cMac.init(keyParam);
348
349        //
350        // build b0
351        //
352        byte[] b0 = new byte[16];
353
354        if (hasAssociatedText())
355        {
356            b0[0] |= 0x40;
357        }
358
359        b0[0] |= (((cMac.getMacSize() - 2) / 2) & 0x7) << 3;
360
361        b0[0] |= ((15 - nonce.length) - 1) & 0x7;
362
363        System.arraycopy(nonce, 0, b0, 1, nonce.length);
364
365        int q = dataLen;
366        int count = 1;
367        while (q > 0)
368        {
369            b0[b0.length - count] = (byte)(q & 0xff);
370            q >>>= 8;
371            count++;
372        }
373
374        cMac.update(b0, 0, b0.length);
375
376        //
377        // process associated text
378        //
379        if (hasAssociatedText())
380        {
381            int extra;
382
383            int textLength = getAssociatedTextLength();
384            if (textLength < ((1 << 16) - (1 << 8)))
385            {
386                cMac.update((byte)(textLength >> 8));
387                cMac.update((byte)textLength);
388
389                extra = 2;
390            }
391            else // can't go any higher than 2^32
392            {
393                cMac.update((byte)0xff);
394                cMac.update((byte)0xfe);
395                cMac.update((byte)(textLength >> 24));
396                cMac.update((byte)(textLength >> 16));
397                cMac.update((byte)(textLength >> 8));
398                cMac.update((byte)textLength);
399
400                extra = 6;
401            }
402
403            if (initialAssociatedText != null)
404            {
405                cMac.update(initialAssociatedText, 0, initialAssociatedText.length);
406            }
407            if (associatedText.size() > 0)
408            {
409                cMac.update(associatedText.getBuffer(), 0, associatedText.size());
410            }
411
412            extra = (extra + textLength) % 16;
413            if (extra != 0)
414            {
415                for (int i = extra; i != 16; i++)
416                {
417                    cMac.update((byte)0x00);
418                }
419            }
420        }
421
422        //
423        // add the text
424        //
425        cMac.update(data, dataOff, dataLen);
426
427        return cMac.doFinal(macBlock, 0);
428    }
429
430    private int getAssociatedTextLength()
431    {
432        return associatedText.size() + ((initialAssociatedText == null) ? 0 : initialAssociatedText.length);
433    }
434
435    private boolean hasAssociatedText()
436    {
437        return getAssociatedTextLength() > 0;
438    }
439
440    private class ExposedByteArrayOutputStream
441        extends ByteArrayOutputStream
442    {
443        public ExposedByteArrayOutputStream()
444        {
445        }
446
447        public byte[] getBuffer()
448        {
449            return this.buf;
450        }
451    }
452}
453