1package org.bouncycastle.crypto.signers;
2
3import java.security.SecureRandom;
4
5import org.bouncycastle.crypto.AsymmetricBlockCipher;
6import org.bouncycastle.crypto.CipherParameters;
7import org.bouncycastle.crypto.CryptoException;
8import org.bouncycastle.crypto.DataLengthException;
9import org.bouncycastle.crypto.Digest;
10import org.bouncycastle.crypto.Signer;
11import org.bouncycastle.crypto.params.ParametersWithRandom;
12import org.bouncycastle.crypto.params.RSAKeyParameters;
13
14/**
15 * RSA-PSS as described in PKCS# 1 v 2.1.
16 * <p>
17 * Note: the usual value for the salt length is the number of
18 * bytes in the hash function.
19 */
20public class PSSSigner
21    implements Signer
22{
23    static final public byte   TRAILER_IMPLICIT    = (byte)0xBC;
24
25    private Digest                      digest;
26    private AsymmetricBlockCipher       cipher;
27    private SecureRandom                random;
28
29    private int                         hLen;
30    private int                         sLen;
31    private int                         emBits;
32    private byte[]                      salt;
33    private byte[]                      mDash;
34    private byte[]                      block;
35    private byte                        trailer;
36
37    /**
38     * basic constructor
39     *
40     * @param cipher the assymetric cipher to use.
41     * @param digest the digest to use.
42     * @param sLen the length of the salt to use (in bytes).
43     */
44    public PSSSigner(
45        AsymmetricBlockCipher   cipher,
46        Digest                  digest,
47        int                     sLen)
48    {
49        this(cipher, digest, sLen, TRAILER_IMPLICIT);
50    }
51
52    public PSSSigner(
53        AsymmetricBlockCipher   cipher,
54        Digest                  digest,
55        int                     sLen,
56        byte                    trailer)
57    {
58        this.cipher = cipher;
59        this.digest = digest;
60        this.hLen = digest.getDigestSize();
61        this.sLen = sLen;
62        this.salt = new byte[sLen];
63        this.mDash = new byte[8 + sLen + hLen];
64        this.trailer = trailer;
65    }
66
67    public void init(
68        boolean                 forSigning,
69        CipherParameters        param)
70    {
71        RSAKeyParameters  kParam = null;
72
73        if (param instanceof ParametersWithRandom)
74        {
75            ParametersWithRandom    p = (ParametersWithRandom)param;
76
77            kParam = (RSAKeyParameters)p.getParameters();
78            random = p.getRandom();
79        }
80        else
81        {
82            kParam = (RSAKeyParameters)param;
83            if (forSigning)
84            {
85                random = new SecureRandom();
86            }
87        }
88
89        cipher.init(forSigning, kParam);
90
91        emBits = kParam.getModulus().bitLength() - 1;
92
93        block = new byte[(emBits + 7) / 8];
94
95        reset();
96    }
97
98    /**
99     * clear possible sensitive data
100     */
101    private void clearBlock(
102        byte[]  block)
103    {
104        for (int i = 0; i != block.length; i++)
105        {
106            block[i] = 0;
107        }
108    }
109
110    /**
111     * update the internal digest with the byte b
112     */
113    public void update(
114        byte    b)
115    {
116        digest.update(b);
117    }
118
119    /**
120     * update the internal digest with the byte array in
121     */
122    public void update(
123        byte[]  in,
124        int     off,
125        int     len)
126    {
127        digest.update(in, off, len);
128    }
129
130    /**
131     * reset the internal state
132     */
133    public void reset()
134    {
135        digest.reset();
136    }
137
138    /**
139     * generate a signature for the message we've been loaded with using
140     * the key we were initialised with.
141     */
142    public byte[] generateSignature()
143        throws CryptoException, DataLengthException
144    {
145        if (emBits < (8 * hLen + 8 * sLen + 9))
146        {
147            throw new DataLengthException("encoding error");
148        }
149
150        digest.doFinal(mDash, mDash.length - hLen - sLen);
151
152        if (sLen != 0)
153        {
154            random.nextBytes(salt);
155
156            System.arraycopy(salt, 0, mDash, mDash.length - sLen, sLen);
157        }
158
159        byte[]  h = new byte[hLen];
160
161        digest.update(mDash, 0, mDash.length);
162
163        digest.doFinal(h, 0);
164
165        block[block.length - sLen - 1 - hLen - 1] = 0x01;
166        System.arraycopy(salt, 0, block, block.length - sLen - hLen - 1, sLen);
167
168        byte[] dbMask = maskGeneratorFunction1(h, 0, h.length, block.length - hLen - 1);
169        for (int i = 0; i != dbMask.length; i++)
170        {
171            block[i] ^= dbMask[i];
172        }
173
174        block[0] &= (0xff >> ((block.length * 8) - emBits));
175
176        System.arraycopy(h, 0, block, block.length - hLen - 1, hLen);
177
178        block[block.length - 1] = trailer;
179
180        byte[]  b = cipher.processBlock(block, 0, block.length);
181
182        clearBlock(block);
183
184        return b;
185    }
186
187    /**
188     * return true if the internal state represents the signature described
189     * in the passed in array.
190     */
191    public boolean verifySignature(
192        byte[]      signature)
193    {
194        if (emBits < (8 * hLen + 8 * sLen + 9))
195        {
196            return false;
197        }
198
199        digest.doFinal(mDash, mDash.length - hLen - sLen);
200
201        try
202        {
203            byte[] b = cipher.processBlock(signature, 0, signature.length);
204            System.arraycopy(b, 0, block, block.length - b.length, b.length);
205        }
206        catch (Exception e)
207        {
208            return false;
209        }
210
211        if (block[block.length - 1] != trailer)
212        {
213            clearBlock(block);
214            return false;
215        }
216
217        byte[] dbMask = maskGeneratorFunction1(block, block.length - hLen - 1, hLen, block.length - hLen - 1);
218
219        for (int i = 0; i != dbMask.length; i++)
220        {
221            block[i] ^= dbMask[i];
222        }
223
224        block[0] &= (0xff >> ((block.length * 8) - emBits));
225
226        for (int i = 0; i != block.length - hLen - sLen - 2; i++)
227        {
228            if (block[i] != 0)
229            {
230                clearBlock(block);
231                return false;
232            }
233        }
234
235        if (block[block.length - hLen - sLen - 2] != 0x01)
236        {
237            clearBlock(block);
238            return false;
239        }
240
241        System.arraycopy(block, block.length - sLen - hLen - 1, mDash, mDash.length - sLen, sLen);
242
243        digest.update(mDash, 0, mDash.length);
244        digest.doFinal(mDash, mDash.length - hLen);
245
246        for (int i = block.length - hLen - 1, j = mDash.length - hLen;
247                                                 j != mDash.length; i++, j++)
248        {
249            if ((block[i] ^ mDash[j]) != 0)
250            {
251                clearBlock(mDash);
252                clearBlock(block);
253                return false;
254            }
255        }
256
257        clearBlock(mDash);
258        clearBlock(block);
259
260        return true;
261    }
262
263    /**
264     * int to octet string.
265     */
266    private void ItoOSP(
267        int     i,
268        byte[]  sp)
269    {
270        sp[0] = (byte)(i >>> 24);
271        sp[1] = (byte)(i >>> 16);
272        sp[2] = (byte)(i >>> 8);
273        sp[3] = (byte)(i >>> 0);
274    }
275
276    /**
277     * mask generator function, as described in PKCS1v2.
278     */
279    private byte[] maskGeneratorFunction1(
280        byte[]  Z,
281        int     zOff,
282        int     zLen,
283        int     length)
284    {
285        byte[]  mask = new byte[length];
286        byte[]  hashBuf = new byte[hLen];
287        byte[]  C = new byte[4];
288        int     counter = 0;
289
290        digest.reset();
291
292        while (counter < (length / hLen))
293        {
294            ItoOSP(counter, C);
295
296            digest.update(Z, zOff, zLen);
297            digest.update(C, 0, C.length);
298            digest.doFinal(hashBuf, 0);
299
300            System.arraycopy(hashBuf, 0, mask, counter * hLen, hLen);
301
302            counter++;
303        }
304
305        if ((counter * hLen) < length)
306        {
307            ItoOSP(counter, C);
308
309            digest.update(Z, zOff, zLen);
310            digest.update(C, 0, C.length);
311            digest.doFinal(hashBuf, 0);
312
313            System.arraycopy(hashBuf, 0, mask, counter * hLen, mask.length - (counter * hLen));
314        }
315
316        return mask;
317    }
318}
319