1package org.bouncycastle.crypto.engines;
2
3import org.bouncycastle.crypto.BlockCipher;
4import org.bouncycastle.crypto.CipherParameters;
5import org.bouncycastle.crypto.DataLengthException;
6import org.bouncycastle.crypto.InvalidCipherTextException;
7import org.bouncycastle.crypto.Wrapper;
8import org.bouncycastle.crypto.params.KeyParameter;
9import org.bouncycastle.crypto.params.ParametersWithIV;
10import org.bouncycastle.crypto.params.ParametersWithRandom;
11import org.bouncycastle.util.Arrays;
12
13/**
14 * an implementation of the AES Key Wrapper from the NIST Key Wrap
15 * Specification as described in RFC 3394.
16 * <p>
17 * For further details see: <a href="http://www.ietf.org/rfc/rfc3394.txt">http://www.ietf.org/rfc/rfc3394.txt</a>
18 * and  <a href="http://csrc.nist.gov/encryption/kms/key-wrap.pdf">http://csrc.nist.gov/encryption/kms/key-wrap.pdf</a>.
19 */
20public class RFC3394WrapEngine
21    implements Wrapper
22{
23    private BlockCipher     engine;
24    private KeyParameter    param;
25    private boolean         forWrapping;
26
27    private byte[]          iv = {
28                              (byte)0xa6, (byte)0xa6, (byte)0xa6, (byte)0xa6,
29                              (byte)0xa6, (byte)0xa6, (byte)0xa6, (byte)0xa6 };
30
31    public RFC3394WrapEngine(BlockCipher engine)
32    {
33        this.engine = engine;
34    }
35
36    public void init(
37        boolean             forWrapping,
38        CipherParameters    param)
39    {
40        this.forWrapping = forWrapping;
41
42        if (param instanceof ParametersWithRandom)
43        {
44            param = ((ParametersWithRandom) param).getParameters();
45        }
46
47        if (param instanceof KeyParameter)
48        {
49            this.param = (KeyParameter)param;
50        }
51        else if (param instanceof ParametersWithIV)
52        {
53            this.iv = ((ParametersWithIV)param).getIV();
54            this.param = (KeyParameter)((ParametersWithIV) param).getParameters();
55            if (this.iv.length != 8)
56            {
57               throw new IllegalArgumentException("IV not equal to 8");
58            }
59        }
60    }
61
62    public String getAlgorithmName()
63    {
64        return engine.getAlgorithmName();
65    }
66
67    public byte[] wrap(
68        byte[]  in,
69        int     inOff,
70        int     inLen)
71    {
72        if (!forWrapping)
73        {
74            throw new IllegalStateException("not set for wrapping");
75        }
76
77        int     n = inLen / 8;
78
79        if ((n * 8) != inLen)
80        {
81            throw new DataLengthException("wrap data must be a multiple of 8 bytes");
82        }
83
84        byte[]  block = new byte[inLen + iv.length];
85        byte[]  buf = new byte[8 + iv.length];
86
87        System.arraycopy(iv, 0, block, 0, iv.length);
88        System.arraycopy(in, 0, block, iv.length, inLen);
89
90        engine.init(true, param);
91
92        for (int j = 0; j != 6; j++)
93        {
94            for (int i = 1; i <= n; i++)
95            {
96                System.arraycopy(block, 0, buf, 0, iv.length);
97                System.arraycopy(block, 8 * i, buf, iv.length, 8);
98                engine.processBlock(buf, 0, buf, 0);
99
100                int t = n * j + i;
101                for (int k = 1; t != 0; k++)
102                {
103                    byte    v = (byte)t;
104
105                    buf[iv.length - k] ^= v;
106
107                    t >>>= 8;
108                }
109
110                System.arraycopy(buf, 0, block, 0, 8);
111                System.arraycopy(buf, 8, block, 8 * i, 8);
112            }
113        }
114
115        return block;
116    }
117
118    public byte[] unwrap(
119        byte[]  in,
120        int     inOff,
121        int     inLen)
122        throws InvalidCipherTextException
123    {
124        if (forWrapping)
125        {
126            throw new IllegalStateException("not set for unwrapping");
127        }
128
129        int     n = inLen / 8;
130
131        if ((n * 8) != inLen)
132        {
133            throw new InvalidCipherTextException("unwrap data must be a multiple of 8 bytes");
134        }
135
136        byte[]  block = new byte[inLen - iv.length];
137        byte[]  a = new byte[iv.length];
138        byte[]  buf = new byte[8 + iv.length];
139
140        System.arraycopy(in, 0, a, 0, iv.length);
141        System.arraycopy(in, iv.length, block, 0, inLen - iv.length);
142
143        engine.init(false, param);
144
145        n = n - 1;
146
147        for (int j = 5; j >= 0; j--)
148        {
149            for (int i = n; i >= 1; i--)
150            {
151                System.arraycopy(a, 0, buf, 0, iv.length);
152                System.arraycopy(block, 8 * (i - 1), buf, iv.length, 8);
153
154                int t = n * j + i;
155                for (int k = 1; t != 0; k++)
156                {
157                    byte    v = (byte)t;
158
159                    buf[iv.length - k] ^= v;
160
161                    t >>>= 8;
162                }
163
164                engine.processBlock(buf, 0, buf, 0);
165                System.arraycopy(buf, 0, a, 0, 8);
166                System.arraycopy(buf, 8, block, 8 * (i - 1), 8);
167            }
168        }
169
170        if (!Arrays.constantTimeAreEqual(a, iv))
171        {
172            throw new InvalidCipherTextException("checksum failed");
173        }
174
175        return block;
176    }
177}
178