1/*
2 * PKCS#1 encoding and decoding functions.
3 * This file is believed to contain no code licensed from other parties.
4 *
5 * This Source Code Form is subject to the terms of the Mozilla Public
6 * License, v. 2.0. If a copy of the MPL was not distributed with this
7 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8
9#include "seccomon.h"
10#include "secerr.h"
11#include "sechash.h"
12
/*
 * Eight zero octets (the "padding1" prefix of M' in EMSA-PSS), hashed
 * ahead of mHash and the salt when recomputing H'.
 */
static const unsigned char eightZeros[8] = { 0 };
15
16/*
17 * Mask generation function MGF1 as defined in PKCS #1 v2.1 / RFC 3447.
18 */
19static SECStatus
20MGF1(HASH_HashType hashAlg, unsigned char *mask, unsigned int maskLen,
21     const unsigned char *mgfSeed, unsigned int mgfSeedLen)
22{
23    unsigned int digestLen;
24    PRUint32 counter, rounds;
25    unsigned char *tempHash, *temp;
26    const SECHashObject *hash;
27    void *hashContext;
28    unsigned char C[4];
29
30    hash = HASH_GetHashObject(hashAlg);
31    if (hash == NULL)
32        return SECFailure;
33
34    hashContext = (*hash->create)();
35    rounds = (maskLen + hash->length - 1) / hash->length;
36    for (counter = 0; counter < rounds; counter++) {
37        C[0] = (unsigned char)((counter >> 24) & 0xff);
38        C[1] = (unsigned char)((counter >> 16) & 0xff);
39        C[2] = (unsigned char)((counter >> 8) & 0xff);
40        C[3] = (unsigned char)(counter & 0xff);
41
42        /* This could be optimized when the clone functions in
43         * rawhash.c are implemented. */
44        (*hash->begin)(hashContext);
45        (*hash->update)(hashContext, mgfSeed, mgfSeedLen);
46        (*hash->update)(hashContext, C, sizeof C);
47
48        tempHash = mask + counter * hash->length;
49        if (counter != (rounds-1)) {
50            (*hash->end)(hashContext, tempHash, &digestLen, hash->length);
51        } else { /* we're in the last round and need to cut the hash */
52            temp = (unsigned char *)PORT_Alloc(hash->length);
53            (*hash->end)(hashContext, temp, &digestLen, hash->length);
54            PORT_Memcpy(tempHash, temp, maskLen - counter * hash->length);
55            PORT_Free(temp);
56        }
57    }
58    (*hash->destroy)(hashContext, PR_TRUE);
59
60    return SECSuccess;
61}
62
63/*
64 * Verify a RSA-PSS signature.
65 * Described in RFC 3447, section 9.1.2.
66 * We use mHash instead of M as input.
67 * emBits from the RFC is just modBits - 1, see section 8.1.2.
68 * We only support MGF1 as the MGF.
69 *
70 * NOTE: this code assumes modBits is a multiple of 8.
71 */
72SECStatus
73emsa_pss_verify(const unsigned char *mHash,
74                const unsigned char *em, unsigned int emLen,
75                HASH_HashType hashAlg, HASH_HashType maskHashAlg,
76                unsigned int sLen)
77{
78    const SECHashObject *hash;
79    void *hash_context;
80    unsigned char *db;
81    unsigned char *H_;  /* H' from the RFC */
82    unsigned int i, dbMaskLen;
83    SECStatus rv;
84
85    hash = HASH_GetHashObject(hashAlg);
86    dbMaskLen = emLen - hash->length - 1;
87
88    /* Step 3 + 4 + 6 */
89    if ((emLen < (hash->length + sLen + 2)) ||
90	(em[emLen - 1] != 0xbc) ||
91	((em[0] & 0x80) != 0)) {
92	PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
93	return SECFailure;
94    }
95
96    /* Step 7 */
97    db = (unsigned char *)PORT_Alloc(dbMaskLen);
98    if (db == NULL) {
99	PORT_SetError(SEC_ERROR_NO_MEMORY);
100	return SECFailure;
101    }
102    /* &em[dbMaskLen] points to H, used as mgfSeed */
103    MGF1(maskHashAlg, db, dbMaskLen, &em[dbMaskLen], hash->length);
104
105    /* Step 8 */
106    for (i = 0; i < dbMaskLen; i++) {
107	db[i] ^= em[i];
108    }
109
110    /* Step 9 */
111    db[0] &= 0x7f;
112
113    /* Step 10 */
114    for (i = 0; i < (dbMaskLen - sLen - 1); i++) {
115	if (db[i] != 0) {
116	    PORT_Free(db);
117	    PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
118	    return SECFailure;
119	}
120    }
121    if (db[dbMaskLen - sLen - 1] != 0x01) {
122	PORT_Free(db);
123	PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
124	return SECFailure;
125    }
126
127    /* Step 12 + 13 */
128    H_ = (unsigned char *)PORT_Alloc(hash->length);
129    if (H_ == NULL) {
130	PORT_Free(db);
131	PORT_SetError(SEC_ERROR_NO_MEMORY);
132	return SECFailure;
133    }
134    hash_context = (*hash->create)();
135    if (hash_context == NULL) {
136	PORT_Free(db);
137	PORT_Free(H_);
138	PORT_SetError(SEC_ERROR_NO_MEMORY);
139	return SECFailure;
140    }
141    (*hash->begin)(hash_context);
142    (*hash->update)(hash_context, eightZeros, 8);
143    (*hash->update)(hash_context, mHash, hash->length);
144    (*hash->update)(hash_context, &db[dbMaskLen - sLen], sLen);
145    (*hash->end)(hash_context, H_, &i, hash->length);
146    (*hash->destroy)(hash_context, PR_TRUE);
147
148    PORT_Free(db);
149
150    /* Step 14 */
151    if (PORT_Memcmp(H_, &em[dbMaskLen], hash->length) != 0) {
152	PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
153	rv = SECFailure;
154    } else {
155	rv = SECSuccess;
156    }
157
158    PORT_Free(H_);
159    return rv;
160}
161